Fix settings update handling in additional STT services
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from deepgram import LiveOptions
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
@@ -114,7 +113,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
STTUpdateSettingsFrame(
|
||||
delta=DeepgramSageMakerSTTSettings(
|
||||
language=Language.ES,
|
||||
live_options=LiveOptions(punctuate=False),
|
||||
punctuate=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from deepgram import LiveOptions
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
@@ -108,7 +107,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
STTUpdateSettingsFrame(
|
||||
delta=DeepgramSTTSettings(
|
||||
language=Language.ES,
|
||||
live_options=LiveOptions(punctuate=False),
|
||||
punctuate=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
await asyncio.sleep(10)
|
||||
logger.info("Updating Gradium STT settings: delay_in_frames=16")
|
||||
await task.queue_frame(STTUpdateSettingsFrame(delta=GradiumSTTSettings(delay_in_frames=5)))
|
||||
await task.queue_frame(STTUpdateSettingsFrame(delta=GradiumSTTSettings(delay_in_frames=16)))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -158,22 +158,12 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
return encoding_map.get(encoding, encoding)
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta.
|
||||
|
||||
Settings are stored but not applied to the active connection.
|
||||
"""
|
||||
"""Apply a settings delta and reconnect if anything changed."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# if changed and self._websocket:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
if changed:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
@@ -68,9 +68,16 @@ def language_to_gradium_language(language: Language) -> Optional[str]:
|
||||
|
||||
@dataclass
|
||||
class GradiumSTTSettings(STTSettings):
|
||||
"""Settings for GradiumSTTService."""
|
||||
"""Settings for GradiumSTTService.
|
||||
|
||||
pass
|
||||
Parameters:
|
||||
delay_in_frames: Delay in audio frames (80ms each) before text is
|
||||
generated. Higher delays allow more context but increase latency.
|
||||
Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
|
||||
Default is 10 (800ms). Lower values like 7-8 give faster response.
|
||||
"""
|
||||
|
||||
delay_in_frames: Optional[int] = None
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
@@ -107,7 +114,6 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
delay_in_frames: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
json_config: Optional[str] = None,
|
||||
settings: Optional[GradiumSTTSettings] = None,
|
||||
@@ -119,9 +125,6 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
delay_in_frames: Delay in audio frames (80ms each) before text is
|
||||
generated. Higher delays allow more context but increase latency.
|
||||
Allowed values: 7, 8, 10, 12, 14, 16, 20, 24, 36, 48.
|
||||
params: Configuration parameters for language and delay settings.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
@@ -151,19 +154,18 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
default_settings = GradiumSTTSettings(
|
||||
model=None,
|
||||
language=None,
|
||||
delay_in_frames=None,
|
||||
)
|
||||
|
||||
# 2. (no deprecated direct args for this service)
|
||||
|
||||
# 3. Apply params overrides — only if settings not provided
|
||||
# 2. Apply params overrides — only if settings not provided
|
||||
if params is not None:
|
||||
_warn_deprecated_param("params", GradiumSTTSettings)
|
||||
if not settings:
|
||||
default_settings.language = params.language
|
||||
if params.delay_in_frames is not None:
|
||||
delay_in_frames = params.delay_in_frames
|
||||
default_settings.delay_in_frames = params.delay_in_frames
|
||||
|
||||
# 4. Apply settings delta (canonical API, always wins)
|
||||
# 3. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
@@ -178,7 +180,6 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._json_config = json_config
|
||||
self._config_delay_in_frames = delay_in_frames
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
@@ -358,8 +359,8 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
gradium_language = language_to_gradium_language(self._settings.language)
|
||||
if gradium_language:
|
||||
json_config["language"] = gradium_language
|
||||
if self._config_delay_in_frames:
|
||||
json_config["delay_in_frames"] = self._config_delay_in_frames
|
||||
if self._settings.delay_in_frames:
|
||||
json_config["delay_in_frames"] = self._settings.delay_in_frames
|
||||
if json_config:
|
||||
setup_msg["json_config"] = json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
|
||||
@@ -907,7 +907,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"]
|
||||
if k in result
|
||||
]
|
||||
logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
|
||||
logger.trace(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
|
||||
|
||||
# Check for errors
|
||||
status = result.get("status", {})
|
||||
|
||||
@@ -400,12 +400,13 @@ class SarvamSTTService(STTService):
|
||||
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
# Prompt is a WebSocket connect-time parameter; reconnect to apply.
|
||||
if "prompt" in changed:
|
||||
# Language and prompt are WebSocket connect-time parameters; reconnect to apply.
|
||||
reconnect_fields = {"language", "prompt"}
|
||||
if changed.keys() & reconnect_fields:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
unhandled = {k: v for k, v in changed.items() if k != "prompt"}
|
||||
unhandled = {k: v for k, v in changed.items() if k not in reconnect_fields}
|
||||
if unhandled:
|
||||
self._warn_unhandled_updated_settings(unhandled)
|
||||
|
||||
@@ -483,7 +484,6 @@ class SarvamSTTService(STTService):
|
||||
Frame: None (transcription results come via WebSocket callbacks).
|
||||
"""
|
||||
if not self._socket_client:
|
||||
logger.warning("WebSocket not connected, cannot process audio")
|
||||
yield None
|
||||
return
|
||||
|
||||
@@ -636,18 +636,22 @@ class SarvamSTTService(STTService):
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
if self._websocket_context and self._socket_client:
|
||||
# Clear references first to prevent run_stt from sending audio
|
||||
# during the close handshake.
|
||||
socket_client = self._socket_client
|
||||
websocket_context = self._websocket_context
|
||||
self._socket_client = None
|
||||
self._websocket_context = None
|
||||
|
||||
if websocket_context and socket_client:
|
||||
try:
|
||||
# Exit the async context manager
|
||||
await self._websocket_context.__aexit__(None, None, None)
|
||||
await websocket_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error closing WebSocket connection: {e}", exception=e
|
||||
)
|
||||
finally:
|
||||
logger.debug("Disconnected from Sarvam WebSocket")
|
||||
self._socket_client = None
|
||||
self._websocket_context = None
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming messages from Sarvam WebSocket.
|
||||
|
||||
@@ -297,9 +297,7 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
await self._connect()
|
||||
|
||||
async def _update_settings(self, delta: SonioxSTTSettings) -> dict[str, Any]:
|
||||
"""Apply settings delta.
|
||||
|
||||
Settings are stored but not applied to the active connection.
|
||||
"""Apply settings delta and reconnect if anything changed.
|
||||
|
||||
Args:
|
||||
delta: A settings delta.
|
||||
@@ -309,15 +307,9 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
if changed:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user