Compare commits

...

3 Commits

Author SHA1 Message Date
Mark Backman
d5b34759d7 Update GradiumTTSService to flush instead of ending stream 2026-01-29 18:47:17 -05:00
Mark Backman
0d697d184a Add delay_in_frames and language support 2026-01-29 18:47:00 -05:00
Mark Backman
717e1ccc01 GradiumSTTService now flushes pending transcripts on VAD stopped detection 2026-01-29 18:47:00 -05:00
6 changed files with 256 additions and 9 deletions

View File

@@ -0,0 +1,3 @@
- Updates to `GradiumSTTService`:
- Now flushes pending transcriptions when VAD detects the user stopped speaking, improving response latency.
- `GradiumSTTService` now supports `InputParams` for configuring `language` and `delay_in_frames` settings.

1
changelog/3596.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed an issue in `GradiumTTSService` where the websocket was being disconnected at the end of every bot turn.

View File

@@ -26,6 +26,7 @@ from pipecat.runner.utils import create_transport
from pipecat.services.gradium.stt import GradiumSTTService
from pipecat.services.gradium.tts import GradiumTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transcriptions.language import Language
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -59,11 +60,18 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
stt = GradiumSTTService(
api_key=os.getenv("GRADIUM_API_KEY"),
api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
params=GradiumSTTService.InputParams(
language=Language.EN,
),
)
tts = GradiumTTSService(
api_key=os.getenv("GRADIUM_API_KEY"),
voice_id="YTpq7expH9539ERJ",
url="wss://us.api.gradium.ai/api/speech/tts",
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

View File

@@ -0,0 +1,86 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.frames.frames import Frame, TranscriptionFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.gradium.stt import GradiumSTTService
from pipecat.transcriptions.language import Language
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
class TranscriptionLogger(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
print(f"Transcription: {frame.text}")
# Push all frames through
await self.push_frame(frame, direction)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(audio_in_enabled=True),
"twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True),
"webrtc": lambda: TransportParams(audio_in_enabled=True),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = GradiumSTTService(
api_key=os.getenv("GRADIUM_API_KEY"),
api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
params=GradiumSTTService.InputParams(language=Language.EN, delay_in_frames=8),
)
tl = TranscriptionLogger()
pipeline = Pipeline([transport.input(), stt, tl])
task = PipelineTask(
pipeline,
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -12,9 +12,10 @@ WebSocket API for streaming audio transcription.
import base64
import json
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
@@ -22,9 +23,12 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
@@ -39,6 +43,26 @@ except ModuleNotFoundError as e:
SAMPLE_RATE = 24000
def language_to_gradium_language(language: Language) -> Optional[str]:
"""Convert a Language enum to Gradium's language code format.
Args:
language: The Language enum value to convert.
Returns:
The Gradium language code string or None if not supported.
"""
LANGUAGE_MAP = {
Language.DE: "de",
Language.EN: "en",
Language.ES: "es",
Language.FR: "fr",
Language.PT: "pt",
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class GradiumSTTService(WebsocketSTTService):
"""Gradium real-time speech-to-text service.
@@ -47,12 +71,29 @@ class GradiumSTTService(WebsocketSTTService):
for audio processing and connection management.
"""
class InputParams(BaseModel):
"""Configuration parameters for Gradium STT API.
Parameters:
language: Expected language of the audio (e.g., "en", "es", "fr").
This helps ground the model to a specific language and improve
transcription quality.
delay_in_frames: Delay in audio frames (80ms each) before text is
generated. Higher delays allow more context but increase latency.
Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
Default is 10 (800ms). Lower values like 7-8 give faster response.
"""
language: Optional[Language] = None
delay_in_frames: Optional[int] = None
def __init__(
self,
*,
api_key: str,
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
json_config: str | None = None,
params: Optional[InputParams] = None,
json_config: Optional[str] = None,
**kwargs,
):
"""Initialize the Gradium STT service.
@@ -60,14 +101,29 @@ class GradiumSTTService(WebsocketSTTService):
Args:
api_key: Gradium API key for authentication.
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
params: Configuration parameters for language and delay settings.
json_config: Optional JSON configuration string for additional model settings.
.. deprecated:: 0.0.101
Use `params` instead for type-safe configuration.
**kwargs: Additional arguments passed to parent STTService class.
"""
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
if json_config is not None:
import warnings
warnings.warn(
"Parameter 'json_config' is deprecated and will be removed in a future version, use 'params' instead.",
DeprecationWarning,
stacklevel=2,
)
self._api_key = api_key
self._api_endpoint_base_url = api_endpoint_base_url
self._websocket = None
self._params = params or GradiumSTTService.InputParams()
self._json_config = json_config
self._receive_task = None
@@ -76,6 +132,11 @@ class GradiumSTTService(WebsocketSTTService):
self._chunk_size_ms = 80
self._chunk_size_bytes = 0
# Set from the ready message when connecting to the service.
# These values are used for flushing transcription.
self._delay_in_frames = 0
self._frame_size = 0
def can_generate_metrics(self) -> bool:
"""Check if the service can generate metrics.
@@ -84,6 +145,17 @@ class GradiumSTTService(WebsocketSTTService):
"""
return True
async def set_language(self, language: Language):
"""Set the recognition language and reconnect.
Args:
language: The language to use for speech recognition.
"""
logger.info(f"Switching STT language to: [{language}]")
self._params.language = language
await self._disconnect()
await self._connect()
async def start(self, frame: StartFrame):
"""Start the speech-to-text service.
@@ -112,6 +184,57 @@ class GradiumSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with VAD-specific handling.
When VAD detects the user has stopped speaking, we flush the transcription
by sending silence frames. This makes the system more reactive by getting
the final transcription faster without closing the connection.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self.start_processing_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
await self._flush_transcription()
async def _flush_transcription(self):
"""Flush the transcription by sending silence frames.
When VAD detects the user stopped speaking, we send delay_in_frames
chunks of silence (zeros) to flush the remaining audio from the model's
buffer. This allows for faster turn-around without closing the connection.
From Gradium docs: "feed in delay_in_frames chunks of silence (vectors
of zeros). If those are fed in faster than realtime, the API also has
a possibility to process them faster."
"""
if not self._websocket or self._websocket.state is not State.OPEN:
return
if self._delay_in_frames <= 0:
logger.debug("No delay_in_frames set, skipping flush")
return
# Create a silence chunk (zeros) of frame_size samples
# Each sample is 2 bytes (16-bit PCM)
silence_bytes = bytes(self._frame_size * 2)
silence_b64 = base64.b64encode(silence_bytes).decode("utf-8")
logger.debug(f"Flushing Gradium STT with {self._delay_in_frames} silence frames")
for _ in range(self._delay_in_frames):
msg = {"type": "audio", "audio": silence_b64}
try:
await self._websocket.send(json.dumps(msg))
except Exception as e:
logger.warning(f"Failed to send silence frame: {e}")
break
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text conversion.
@@ -122,7 +245,6 @@ class GradiumSTTService(WebsocketSTTService):
None (processing handled via WebSocket messages).
"""
self._audio_buffer.extend(audio)
await self.start_processing_metrics()
while len(self._audio_buffer) >= self._chunk_size_bytes:
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
@@ -151,6 +273,9 @@ class GradiumSTTService(WebsocketSTTService):
try:
if self._websocket and self._websocket.state is State.OPEN:
return
logger.debug("Connecting to Gradium STT")
ws_url = self._api_endpoint_base_url
headers = {
"x-api-key": self._api_key,
@@ -165,8 +290,18 @@ class GradiumSTTService(WebsocketSTTService):
"type": "setup",
"input_format": "pcm",
}
if self._json_config is not None:
setup_msg["json_config"] = self._json_config
# Build json_config: start with deprecated json_config, then override with params
json_config = {}
if self._json_config:
json_config = json.loads(self._json_config)
if self._params.language:
gradium_language = language_to_gradium_language(self._params.language)
if gradium_language:
json_config["language"] = gradium_language
if self._params.delay_in_frames:
json_config["delay_in_frames"] = self._params.delay_in_frames
if json_config:
setup_msg["json_config"] = json_config
await self._websocket.send(json.dumps(setup_msg))
ready_msg = await self._websocket.recv()
ready_msg = json.loads(ready_msg)
@@ -175,6 +310,14 @@ class GradiumSTTService(WebsocketSTTService):
if ready_msg["type"] != "ready":
raise Exception(f"unexpected first message type {ready_msg['type']}")
# Store delay_in_frames and frame_size for silence flushing
self._delay_in_frames = ready_msg.get("delay_in_frames", 0)
self._frame_size = ready_msg.get("frame_size", 1920)
logger.debug(
f"Connected to Gradium STT (delay_in_frames={self._delay_in_frames}, "
f"frame_size={self._frame_size})"
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise
@@ -240,3 +383,5 @@ class GradiumSTTService(WebsocketSTTService):
time_now_iso8601(),
)
)
await self._trace_transcription(text, is_final=True, language=None)
await self.stop_processing_metrics()

View File

@@ -232,11 +232,15 @@ class GradiumTTSService(InterruptibleWordTTSService):
raise Exception("Websocket not connected")
async def flush_audio(self):
"""Flush any pending audio synthesis."""
"""Flush any pending audio synthesis.
Sends a <flush> tag to force the model to output audio for all text
that has been input so far, without closing the connection.
"""
if not self._websocket:
return
try:
msg = {"type": "end_of_stream"}
msg = {"type": "text", "text": "<flush>"}
await self._websocket.send(json.dumps(msg))
except ConnectionClosedOK:
logger.debug(f"{self}: connection closed normally during flush")