From 5d6176398752babe6ae2b5b6178efb8c31d97502 Mon Sep 17 00:00:00 2001 From: filipi87 Date: Tue, 12 May 2026 18:20:19 -0300 Subject: [PATCH] Refactoring how we are reconnecting the STT. --- src/pipecat/services/nvidia/sagemaker/stt.py | 47 +++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/pipecat/services/nvidia/sagemaker/stt.py b/src/pipecat/services/nvidia/sagemaker/stt.py index 34b6d7519..cf6c8c6d8 100644 --- a/src/pipecat/services/nvidia/sagemaker/stt.py +++ b/src/pipecat/services/nvidia/sagemaker/stt.py @@ -194,19 +194,31 @@ class NvidiaSageMakerSTTService(STTService): # ── Connection management ───────────────────────────────────────────────── + async def _open_client_session(self): + self._client = SageMakerBidiClient( + endpoint_name=self._endpoint_name, + region=self._region, + model_query_string=None, + model_invocation_path=None, + ) + await self._client.start_session() + await self._send_session_config() + + async def _close_client_session(self): + if self._client and self._client.is_active: + try: + await self._client.send_json({"type": "session.end"}) + except Exception as e: + logger.warning(f"{self}: error sending session.end: {e}") + await self._client.close_session() + self._client = None + async def _connect(self): logger.debug( f"{self}: connecting to SageMaker bidi-stream endpoint '{self._endpoint_name}'" ) try: - self._client = SageMakerBidiClient( - endpoint_name=self._endpoint_name, - region=self._region, - model_query_string=None, - model_invocation_path=None, - ) - await self._client.start_session() - await self._send_session_config() + await self._open_client_session() self._response_task = self.create_task(self._process_responses()) logger.debug(f"{self}: connected") await self._call_event_handler("on_connected") @@ -219,19 +231,13 @@ class NvidiaSageMakerSTTService(STTService): if self._response_task and not self._response_task.done(): await self.cancel_task(self._response_task) self._response_task = None - - if self._client and self._client.is_active: - logger.debug(f"{self}: disconnecting") - try: - await self._client.send_json({"type": "session.end"}) - except Exception as e: - logger.warning(f"{self}: error sending session.end: {e}") - await self._client.close_session() - logger.debug(f"{self}: disconnected") - - self._client = None + await self._close_client_session() await self._call_event_handler("on_disconnected") + async def _do_reconnect(self): + await self._close_client_session() + await self._open_client_session() + async def _send_session_config(self): """Send transcription_session.update to configure audio format and params. @@ -333,8 +339,7 @@ class NvidiaSageMakerSTTService(STTService): ): await self.push_error(error_msg=f"NIM ASR error: {msg}") # In case of error we need to reconnect, otherwise we are not going to receive from the STT service anymore - await self._disconnect() - await self._connect() + await self._request_reconnect() except asyncio.CancelledError: logger.debug(f"{self}: response processor cancelled")