Fix settings update handling in additional STT services

This commit is contained in:
Mark Backman
2026-03-06 21:52:45 -05:00
parent 9c42d27f4d
commit ec93cd1d51
8 changed files with 40 additions and 55 deletions

View File

@@ -7,7 +7,6 @@
import asyncio
import os
from deepgram import LiveOptions
from dotenv import load_dotenv
from loguru import logger
@@ -114,7 +113,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
STTUpdateSettingsFrame(
delta=DeepgramSageMakerSTTSettings(
language=Language.ES,
live_options=LiveOptions(punctuate=False),
punctuate=False,
)
)
)

View File

@@ -7,7 +7,6 @@
import asyncio
import os
from deepgram import LiveOptions
from dotenv import load_dotenv
from loguru import logger
@@ -108,7 +107,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
STTUpdateSettingsFrame(
delta=DeepgramSTTSettings(
language=Language.ES,
live_options=LiveOptions(punctuate=False),
punctuate=False,
)
)
)

View File

@@ -104,7 +104,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
await asyncio.sleep(10)
logger.info("Updating Gradium STT settings: delay_in_frames=5")
await task.queue_frame(STTUpdateSettingsFrame(delta=GradiumSTTSettings(delay_in_frames=5)))
await task.queue_frame(STTUpdateSettingsFrame(delta=GradiumSTTSettings(delay_in_frames=16)))
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):

View File

@@ -158,22 +158,12 @@ class AWSTranscribeSTTService(WebsocketSTTService):
return encoding_map.get(encoding, encoding)
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
"""Apply a settings delta.
Settings are stored but not applied to the active connection.
"""
"""Apply a settings delta and reconnect if anything changed."""
changed = await super()._update_settings(delta)
if not changed:
return changed
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# if changed and self._websocket:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
if changed:
await self._disconnect()
await self._connect()
return changed

View File

@@ -68,9 +68,16 @@ def language_to_gradium_language(language: Language) -> Optional[str]:
@dataclass
class GradiumSTTSettings(STTSettings):
"""Settings for GradiumSTTService."""
"""Settings for GradiumSTTService.
pass
Parameters:
delay_in_frames: Delay in audio frames (80ms each) before text is
generated. Higher delays allow more context but increase latency.
Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
Default is 10 (800ms). Lower values like 7-8 give faster response.
"""
delay_in_frames: Optional[int] = None
class GradiumSTTService(WebsocketSTTService):
@@ -107,7 +114,6 @@ class GradiumSTTService(WebsocketSTTService):
*,
api_key: str,
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
delay_in_frames: Optional[int] = None,
params: Optional[InputParams] = None,
json_config: Optional[str] = None,
settings: Optional[GradiumSTTSettings] = None,
@@ -119,9 +125,6 @@ class GradiumSTTService(WebsocketSTTService):
Args:
api_key: Gradium API key for authentication.
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
delay_in_frames: Delay in audio frames (80ms each) before text is
generated. Higher delays allow more context but increase latency.
Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
params: Configuration parameters for language and delay settings.
.. deprecated:: 0.0.105
@@ -151,19 +154,18 @@ class GradiumSTTService(WebsocketSTTService):
default_settings = GradiumSTTSettings(
model=None,
language=None,
delay_in_frames=None,
)
# 2. (no deprecated direct args for this service)
# 3. Apply params overrides — only if settings not provided
# 2. Apply params overrides — only if settings not provided
if params is not None:
_warn_deprecated_param("params", GradiumSTTSettings)
if not settings:
default_settings.language = params.language
if params.delay_in_frames is not None:
delay_in_frames = params.delay_in_frames
default_settings.delay_in_frames = params.delay_in_frames
# 4. Apply settings delta (canonical API, always wins)
# 3. Apply settings delta (canonical API, always wins)
if settings is not None:
default_settings.apply_update(settings)
@@ -178,7 +180,6 @@ class GradiumSTTService(WebsocketSTTService):
self._api_endpoint_base_url = api_endpoint_base_url
self._websocket = None
self._json_config = json_config
self._config_delay_in_frames = delay_in_frames
self._receive_task = None
@@ -358,8 +359,8 @@ class GradiumSTTService(WebsocketSTTService):
gradium_language = language_to_gradium_language(self._settings.language)
if gradium_language:
json_config["language"] = gradium_language
if self._config_delay_in_frames:
json_config["delay_in_frames"] = self._config_delay_in_frames
if self._settings.delay_in_frames:
json_config["delay_in_frames"] = self._settings.delay_in_frames
if json_config:
setup_msg["json_config"] = json_config
await self._websocket.send(json.dumps(setup_msg))

View File

@@ -907,7 +907,7 @@ class InworldTTSService(WebsocketTTSService):
for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"]
if k in result
]
logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
logger.trace(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
# Check for errors
status = result.get("status", {})

View File

@@ -400,12 +400,13 @@ class SarvamSTTService(STTService):
changed = await super()._update_settings(delta)
# Prompt is a WebSocket connect-time parameter; reconnect to apply.
if "prompt" in changed:
# Language and prompt are WebSocket connect-time parameters; reconnect to apply.
reconnect_fields = {"language", "prompt"}
if changed.keys() & reconnect_fields:
await self._disconnect()
await self._connect()
unhandled = {k: v for k, v in changed.items() if k != "prompt"}
unhandled = {k: v for k, v in changed.items() if k not in reconnect_fields}
if unhandled:
self._warn_unhandled_updated_settings(unhandled)
@@ -483,7 +484,6 @@ class SarvamSTTService(STTService):
Frame: None (transcription results come via WebSocket callbacks).
"""
if not self._socket_client:
logger.warning("WebSocket not connected, cannot process audio")
yield None
return
@@ -636,18 +636,22 @@ class SarvamSTTService(STTService):
await self.cancel_task(self._receive_task)
self._receive_task = None
if self._websocket_context and self._socket_client:
# Clear references first to prevent run_stt from sending audio
# during the close handshake.
socket_client = self._socket_client
websocket_context = self._websocket_context
self._socket_client = None
self._websocket_context = None
if websocket_context and socket_client:
try:
# Exit the async context manager
await self._websocket_context.__aexit__(None, None, None)
await websocket_context.__aexit__(None, None, None)
except Exception as e:
await self.push_error(
error_msg=f"Error closing WebSocket connection: {e}", exception=e
)
finally:
logger.debug("Disconnected from Sarvam WebSocket")
self._socket_client = None
self._websocket_context = None
async def _receive_task_handler(self):
"""Handle incoming messages from Sarvam WebSocket.

View File

@@ -297,9 +297,7 @@ class SonioxSTTService(WebsocketSTTService):
await self._connect()
async def _update_settings(self, delta: SonioxSTTSettings) -> dict[str, Any]:
"""Apply settings delta.
Settings are stored but not applied to the active connection.
"""Apply settings delta and reconnect if anything changed.
Args:
delta: A settings delta.
@@ -309,15 +307,9 @@ class SonioxSTTService(WebsocketSTTService):
"""
changed = await super()._update_settings(delta)
if not changed:
return changed
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
if changed:
await self._disconnect()
await self._connect()
return changed