Refactoring how we are reconnecting the STT.

This commit is contained in:
filipi87
2026-05-12 18:20:19 -03:00
parent 7984556692
commit 5d61763987

View File

@@ -194,19 +194,31 @@ class NvidiaSageMakerSTTService(STTService):
# ── Connection management ─────────────────────────────────────────────────
async def _open_client_session(self):
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_query_string=None,
model_invocation_path=None,
)
await self._client.start_session()
await self._send_session_config()
async def _close_client_session(self):
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "session.end"})
except Exception as e:
logger.warning(f"{self}: error sending session.end: {e}")
await self._client.close_session()
self._client = None
async def _connect(self):
logger.debug(
f"{self}: connecting to SageMaker bidi-stream endpoint '{self._endpoint_name}'"
)
try:
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_query_string=None,
model_invocation_path=None,
)
await self._client.start_session()
await self._send_session_config()
await self._open_client_session()
self._response_task = self.create_task(self._process_responses())
logger.debug(f"{self}: connected")
await self._call_event_handler("on_connected")
@@ -219,19 +231,13 @@ class NvidiaSageMakerSTTService(STTService):
if self._response_task and not self._response_task.done():
await self.cancel_task(self._response_task)
self._response_task = None
if self._client and self._client.is_active:
logger.debug(f"{self}: disconnecting")
try:
await self._client.send_json({"type": "session.end"})
except Exception as e:
logger.warning(f"{self}: error sending session.end: {e}")
await self._client.close_session()
logger.debug(f"{self}: disconnected")
self._client = None
await self._close_client_session()
await self._call_event_handler("on_disconnected")
async def _do_reconnect(self):
await self._close_client_session()
await self._open_client_session()
async def _send_session_config(self):
"""Send transcription_session.update to configure audio format and params.
@@ -333,8 +339,7 @@ class NvidiaSageMakerSTTService(STTService):
):
await self.push_error(error_msg=f"NIM ASR error: {msg}")
# In case of error we need to reconnect, otherwise we are not going to receive from the STT service anymore
await self._disconnect()
await self._connect()
await self._request_reconnect()
except asyncio.CancelledError:
logger.debug(f"{self}: response processor cancelled")