Refactoring how we are reconnecting the STT.
This commit is contained in:
@@ -194,19 +194,31 @@ class NvidiaSageMakerSTTService(STTService):
|
||||
|
||||
# ── Connection management ─────────────────────────────────────────────────
|
||||
|
||||
async def _open_client_session(self):
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_query_string=None,
|
||||
model_invocation_path=None,
|
||||
)
|
||||
await self._client.start_session()
|
||||
await self._send_session_config()
|
||||
|
||||
async def _close_client_session(self):
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "session.end"})
|
||||
except Exception as e:
|
||||
logger.warning(f"{self}: error sending session.end: {e}")
|
||||
await self._client.close_session()
|
||||
self._client = None
|
||||
|
||||
async def _connect(self):
|
||||
logger.debug(
|
||||
f"{self}: connecting to SageMaker bidi-stream endpoint '{self._endpoint_name}'"
|
||||
)
|
||||
try:
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_query_string=None,
|
||||
model_invocation_path=None,
|
||||
)
|
||||
await self._client.start_session()
|
||||
await self._send_session_config()
|
||||
await self._open_client_session()
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
logger.debug(f"{self}: connected")
|
||||
await self._call_event_handler("on_connected")
|
||||
@@ -219,19 +231,13 @@ class NvidiaSageMakerSTTService(STTService):
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug(f"{self}: disconnecting")
|
||||
try:
|
||||
await self._client.send_json({"type": "session.end"})
|
||||
except Exception as e:
|
||||
logger.warning(f"{self}: error sending session.end: {e}")
|
||||
await self._client.close_session()
|
||||
logger.debug(f"{self}: disconnected")
|
||||
|
||||
self._client = None
|
||||
await self._close_client_session()
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _do_reconnect(self):
|
||||
await self._close_client_session()
|
||||
await self._open_client_session()
|
||||
|
||||
async def _send_session_config(self):
|
||||
"""Send transcription_session.update to configure audio format and params.
|
||||
|
||||
@@ -333,8 +339,7 @@ class NvidiaSageMakerSTTService(STTService):
|
||||
):
|
||||
await self.push_error(error_msg=f"NIM ASR error: {msg}")
|
||||
# In case of error we need to reconnect, otherwise we are not going to receive from the STT service anymore
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
await self._request_reconnect()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self}: response processor cancelled")
|
||||
|
||||
Reference in New Issue
Block a user