Compare commits
3 Commits
hush/conte
...
mb/gradium
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5b34759d7 | ||
|
|
0d697d184a | ||
|
|
717e1ccc01 |
3
changelog/3587.changed.md
Normal file
3
changelog/3587.changed.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Updates to `GradiumSTTService`:
|
||||
- Now flushes pending transcriptions when VAD detects that the user has stopped speaking, improving response latency.
|
||||
- `GradiumSTTService` now accepts `InputParams` (via the `params` argument) for configuring the `language` and `delay_in_frames` settings.
|
||||
1
changelog/3596.fixed.md
Normal file
1
changelog/3596.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed an issue in `GradiumTTSService` where the websocket was being disconnected at the end of every bot turn.
|
||||
@@ -26,6 +26,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,11 +60,18 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
stt = GradiumSTTService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
|
||||
params=GradiumSTTService.InputParams(
|
||||
language=Language.EN,
|
||||
),
|
||||
)
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
url="wss://us.api.gradium.ai/api/speech/tts",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
86
examples/foundational/13l-gradium-transcription.py
Normal file
86
examples/foundational/13l-gradium-transcription.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import Frame, TranscriptionFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TranscriptionLogger(FrameProcessor):
    """Frame processor that prints the text of each transcription frame.

    Every frame, transcription or not, is forwarded in the direction it
    arrived in.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Print transcription text (if any) and pass the frame along.

        Args:
            frame: The frame travelling through the pipeline.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        is_transcription = isinstance(frame, TranscriptionFrame)
        if is_transcription:
            print(f"Transcription: {frame.text}")

        # Nothing is consumed here: every frame continues on its way.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# Each entry is a zero-argument factory rather than a params object, so
# nothing is instantiated at import time. Presumably the factory is invoked
# once the desired transport has been selected (see `create_transport`).
transport_params = dict(
    daily=lambda: DailyParams(audio_in_enabled=True),
    twilio=lambda: FastAPIWebsocketParams(audio_in_enabled=True),
    webrtc=lambda: TransportParams(audio_in_enabled=True),
)
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Run a transcription-only pipeline on the given transport.

    The pipeline is: transport input -> Gradium STT -> TranscriptionLogger.
    The task is cancelled when the transport reports a client disconnect.

    Args:
        transport: The transport whose input feeds the pipeline.
        runner_args: Runner settings; supplies the pipeline idle timeout and
            the SIGINT handling flag.
    """
    logger.info(f"Starting bot")

    gradium_stt = GradiumSTTService(
        api_key=os.getenv("GRADIUM_API_KEY"),
        api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
        params=GradiumSTTService.InputParams(
            language=Language.EN,
            delay_in_frames=8,
        ),
    )
    transcription_logger = TranscriptionLogger()

    processors = [transport.input(), gradium_stt, transcription_logger]
    task = PipelineTask(
        Pipeline(processors),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        # Stop the pipeline as soon as the client goes away.
        logger.info(f"Client disconnected")
        await task.cancel()

    pipeline_runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
    await pipeline_runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud.

    Creates the transport from the runner arguments and the module-level
    `transport_params` table, then hands it to `run_bot`.

    Args:
        runner_args: Runner arguments passed through to `create_transport`
            and `run_bot`.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The runner is imported only when this file is executed as a script.
    from pipecat.runner.run import main

    main()
|
||||
@@ -12,9 +12,10 @@ WebSocket API for streaming audio transcription.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -22,9 +23,12 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
@@ -39,6 +43,26 @@ except ModuleNotFoundError as e:
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
def language_to_gradium_language(language: Language) -> Optional[str]:
    """Translate a `Language` enum value into a Gradium language code.

    The lookup is delegated to `resolve_language` with `use_base_code=True`.

    Args:
        language: The Language enum value to convert.

    Returns:
        The Gradium language code string, or None if `resolve_language`
        yields no code for the given language.
    """
    # Languages with an explicit Gradium code.
    gradium_codes = {
        Language.DE: "de",
        Language.EN: "en",
        Language.ES: "es",
        Language.FR: "fr",
        Language.PT: "pt",
    }

    gradium_code = resolve_language(language, gradium_codes, use_base_code=True)
    return gradium_code
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
"""Gradium real-time speech-to-text service.
|
||||
|
||||
@@ -47,12 +71,29 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
for audio processing and connection management.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
    """Configuration parameters for Gradium STT API.

    Both fields default to None.

    Parameters:
        language: Expected language of the audio, given as a `Language`
            value (e.g., `Language.EN`). This helps ground the model to a
            specific language and improve transcription quality.
        delay_in_frames: Delay in audio frames (80ms each) before text is
            generated. Higher delays allow more context but increase latency.
            Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
            Default is 10 (800ms). Lower values like 7-8 give faster response.
            NOTE(review): the allowed values are not validated here.
    """

    language: Optional[Language] = None
    delay_in_frames: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
json_config: str | None = None,
|
||||
params: Optional[InputParams] = None,
|
||||
json_config: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
@@ -60,14 +101,29 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
params: Configuration parameters for language and delay settings.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
|
||||
.. deprecated:: 0.0.101
|
||||
Use `params` instead for type-safe configuration.
|
||||
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
|
||||
if json_config is not None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Parameter 'json_config' is deprecated and will be removed in a future version, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._params = params or GradiumSTTService.InputParams()
|
||||
self._json_config = json_config
|
||||
|
||||
self._receive_task = None
|
||||
@@ -76,6 +132,11 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
self._chunk_size_ms = 80
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
# Set from the ready message when connecting to the service.
|
||||
# These values are used for flushing transcription.
|
||||
self._delay_in_frames = 0
|
||||
self._frame_size = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
@@ -84,6 +145,17 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_language(self, language: Language):
    """Switch the recognition language and re-establish the connection.

    Args:
        language: The language to use for speech recognition.
    """
    logger.info(f"Switching STT language to: [{language}]")

    # Record the new language on the params, then cycle the connection.
    # Presumably the language only takes effect when a connection is set up.
    self._params.language = language

    await self._disconnect()
    await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech-to-text service.
|
||||
|
||||
@@ -112,6 +184,57 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Handle VAD start/stop frames on top of the base frame processing.

    A VAD "started speaking" frame starts the processing metrics. A VAD
    "stopped speaking" frame triggers a transcription flush, so pending
    text is requested from the service without closing the connection.

    Args:
        frame: The frame to process.
        direction: The direction of frame processing.
    """
    await super().process_frame(frame, direction)

    if isinstance(frame, VADUserStartedSpeakingFrame):
        await self.start_processing_metrics()
        return

    if isinstance(frame, VADUserStoppedSpeakingFrame):
        await self._flush_transcription()
|
||||
|
||||
async def _flush_transcription(self):
    """Flush the transcription by sending silence frames.

    When VAD detects the user stopped speaking, we send delay_in_frames
    chunks of silence (zeros) to flush the remaining audio from the model's
    buffer. This allows for faster turn-around without closing the connection.

    From Gradium docs: "feed in delay_in_frames chunks of silence (vectors
    of zeros). If those are fed in faster than realtime, the API also has
    a possibility to process them faster."

    Does nothing when the websocket is missing or not open, or when no
    positive delay_in_frames is stored. Sending is best-effort: the first
    failure is logged as a warning and stops the flush.
    """
    if not self._websocket or self._websocket.state is not State.OPEN:
        return

    if self._delay_in_frames <= 0:
        logger.debug("No delay_in_frames set, skipping flush")
        return

    # One silence chunk: frame_size zero samples of 2 bytes each (16-bit PCM).
    silence_bytes = bytes(self._frame_size * 2)
    silence_b64 = base64.b64encode(silence_bytes).decode("utf-8")

    # Every silence message is identical, so serialise it once up front
    # instead of rebuilding and re-encoding it on each iteration.
    payload = json.dumps({"type": "audio", "audio": silence_b64})

    logger.debug(f"Flushing Gradium STT with {self._delay_in_frames} silence frames")

    try:
        for _ in range(self._delay_in_frames):
            await self._websocket.send(payload)
    except Exception as e:
        # Deliberately broad: a failed flush must not break frame processing.
        logger.warning(f"Failed to send silence frame: {e}")
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text conversion.
|
||||
|
||||
@@ -122,7 +245,6 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
None (processing handled via WebSocket messages).
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
await self.start_processing_metrics()
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
@@ -151,6 +273,9 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to Gradium STT")
|
||||
|
||||
ws_url = self._api_endpoint_base_url
|
||||
headers = {
|
||||
"x-api-key": self._api_key,
|
||||
@@ -165,8 +290,18 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
"type": "setup",
|
||||
"input_format": "pcm",
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
# Build json_config: start with deprecated json_config, then override with params
|
||||
json_config = {}
|
||||
if self._json_config:
|
||||
json_config = json.loads(self._json_config)
|
||||
if self._params.language:
|
||||
gradium_language = language_to_gradium_language(self._params.language)
|
||||
if gradium_language:
|
||||
json_config["language"] = gradium_language
|
||||
if self._params.delay_in_frames:
|
||||
json_config["delay_in_frames"] = self._params.delay_in_frames
|
||||
if json_config:
|
||||
setup_msg["json_config"] = json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
@@ -175,6 +310,14 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
# Store delay_in_frames and frame_size for silence flushing
|
||||
self._delay_in_frames = ready_msg.get("delay_in_frames", 0)
|
||||
self._frame_size = ready_msg.get("frame_size", 1920)
|
||||
logger.debug(
|
||||
f"Connected to Gradium STT (delay_in_frames={self._delay_in_frames}, "
|
||||
f"frame_size={self._frame_size})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
@@ -240,3 +383,5 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(text, is_final=True, language=None)
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
@@ -232,11 +232,15 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
"""Flush any pending audio synthesis.
|
||||
|
||||
Sends a <flush> tag to force the model to output audio for all text
|
||||
that has been input so far, without closing the connection.
|
||||
"""
|
||||
if not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
msg = {"type": "text", "text": "<flush>"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
|
||||
Reference in New Issue
Block a user