diff --git a/CHANGELOG.md b/CHANGELOG.md index 433555890..acd680d5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `on_connected` and `on_disconnected` events to TTS and STT + websocket-based services. + - Added an `aggregate_sentences` arg in `ElevenLabsHttpTTSService`, where the default value is True. diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index aa2fc36bc..381d60506 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -197,6 +197,8 @@ class AssemblyAISTTService(STTService): ) self._connected = True self._receive_task = self.create_task(self._receive_task_handler()) + + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"Failed to connect to AssemblyAI: {e}") self._connected = False @@ -238,6 +240,7 @@ class AssemblyAISTTService(STTService): self._websocket = None self._connected = False self._receive_task = None + await self._call_event_handler("on_disconnected") async def _receive_task_handler(self): """Handle incoming WebSocket messages.""" diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index b0ddf9275..3e4ff33cc 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -235,6 +235,8 @@ class AsyncAITTSService(InterruptibleTTSService): } await self._get_websocket().send(json.dumps(init_msg)) + + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -252,6 +254,7 @@ class AsyncAITTSService(InterruptibleTTSService): finally: self._websocket = None self._started = False + await self._call_event_handler("on_disconnected") def _get_websocket(self): if self._websocket: diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index 59cda5865..b019fc058 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -286,6 +286,7 @@ class AWSTranscribeSTTService(STTService): logger.info(f"{self} Successfully connected to AWS Transcribe") + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} Failed to connect to AWS Transcribe: {e}") await self._disconnect() @@ -310,6 +311,7 @@ class AWSTranscribeSTTService(STTService): logger.warning(f"{self} Error closing WebSocket connection: {e}") finally: self._ws_client = None + await self._call_event_handler("on_disconnected") def language_to_service_language(self, language: Language) -> str | None: """Convert internal language enum to AWS Transcribe language code. diff --git a/src/pipecat/services/cartesia/stt.py b/src/pipecat/services/cartesia/stt.py index 97a4f7127..b4e232c4a 100644 --- a/src/pipecat/services/cartesia/stt.py +++ b/src/pipecat/services/cartesia/stt.py @@ -273,6 +273,7 @@ class CartesiaSTTService(WebsocketSTTService): headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key} self._websocket = await websocket_connect(ws_url, additional_headers=headers) + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self}: unable to connect to Cartesia: {e}") @@ -285,6 +286,7 @@ class CartesiaSTTService(WebsocketSTTService): logger.error(f"{self} error closing websocket: {e}") finally: self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): if self._websocket: diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index 9b2475cb6..90f0ac3b4 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -348,6 +348,7 @@ class CartesiaTTSService(AudioContextWordTTSService): self._websocket = await websocket_connect( f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}" ) + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -365,6 +366,7 @@ class CartesiaTTSService(AudioContextWordTTSService): finally: self._context_id = None self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): if self._websocket: diff --git a/src/pipecat/services/deepgram/flux/stt.py b/src/pipecat/services/deepgram/flux/stt.py index 493bece80..f0b1a5baa 100644 --- a/src/pipecat/services/deepgram/flux/stt.py +++ b/src/pipecat/services/deepgram/flux/stt.py @@ -205,6 +205,7 @@ class DeepgramFluxSTTService(WebsocketSTTService): additional_headers={"Authorization": f"Token {self._api_key}"}, ) logger.debug("Connected to Deepgram Flux Websocket") + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -225,6 +226,9 @@ class DeepgramFluxSTTService(WebsocketSTTService): await self._websocket.close() except Exception as e: logger.error(f"{self} error closing websocket: {e}") + finally: + self._websocket = None + await self._call_event_handler("on_disconnected") async def _send_close_stream(self) -> None: """Sends a CloseStream control message to the Deepgram Flux WebSocket API. diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index 080f54cd6..641f50aa1 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -528,6 +528,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService): url, max_size=16 * 1024 * 1024, additional_headers={"xi-api-key": self._api_key} ) + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -550,6 +551,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService): self._started = False self._context_id = None self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): if self._websocket: diff --git a/src/pipecat/services/fish/tts.py b/src/pipecat/services/fish/tts.py index b39b775e5..669d2ce97 100644 --- a/src/pipecat/services/fish/tts.py +++ b/src/pipecat/services/fish/tts.py @@ -225,6 +225,8 @@ class FishAudioTTSService(InterruptibleTTSService): start_message = {"event": "start", "request": {"text": "", **self._settings}} await self._websocket.send(ormsgpack.packb(start_message)) logger.debug("Sent start message to Fish Audio") + + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"Fish Audio initialization error: {e}") self._websocket = None @@ -245,6 +247,7 @@ class FishAudioTTSService(InterruptibleTTSService): self._request_id = None self._started = False self._websocket = None + await self._call_event_handler("on_disconnected") async def flush_audio(self): """Flush any buffered audio by sending a flush event to Fish Audio.""" diff --git a/src/pipecat/services/google/stt.py b/src/pipecat/services/google/stt.py index 31ae597f7..b9e56f55b 100644 --- a/src/pipecat/services/google/stt.py +++ b/src/pipecat/services/google/stt.py @@ -730,6 +730,8 @@ class GoogleSTTService(STTService): self._request_queue = asyncio.Queue() self._streaming_task = self.create_task(self._stream_audio()) + await self._call_event_handler("on_connected") + async def _disconnect(self): """Clean up streaming recognition resources.""" if self._streaming_task: @@ -737,6 +739,8 @@ class GoogleSTTService(STTService): await self.cancel_task(self._streaming_task) self._streaming_task = None + await self._call_event_handler("on_disconnected") + async def _request_generator(self): """Generates requests for the streaming recognize method.""" recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_" diff --git a/src/pipecat/services/lmnt/tts.py b/src/pipecat/services/lmnt/tts.py index a602789fd..9f9fef5fc 100644 --- a/src/pipecat/services/lmnt/tts.py +++ b/src/pipecat/services/lmnt/tts.py @@ -222,6 +222,7 @@ class LmntTTSService(InterruptibleTTSService): # Send initialization message await self._websocket.send(json.dumps(init_msg)) + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -243,6 +244,7 @@ class LmntTTSService(InterruptibleTTSService): finally: self._started = False self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): """Get the WebSocket connection if available.""" diff --git a/src/pipecat/services/neuphonic/tts.py b/src/pipecat/services/neuphonic/tts.py index 46d805086..6ccdfe17f 100644 --- a/src/pipecat/services/neuphonic/tts.py +++ b/src/pipecat/services/neuphonic/tts.py @@ -293,6 +293,8 @@ class NeuphonicTTSService(InterruptibleTTSService): headers = {"x-api-key": self._api_key} self._websocket = await websocket_connect(url, additional_headers=headers) + + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -311,6 +313,7 @@ class NeuphonicTTSService(InterruptibleTTSService): finally: self._started = False self._websocket = None + await self._call_event_handler("on_disconnected") async def _receive_messages(self): """Receive and process messages from Neuphonic WebSocket.""" diff --git a/src/pipecat/services/playht/tts.py b/src/pipecat/services/playht/tts.py index 9288ebd59..925480794 100644 --- a/src/pipecat/services/playht/tts.py +++ b/src/pipecat/services/playht/tts.py @@ -269,6 +269,8 @@ class PlayHTTTSService(InterruptibleTTSService): raise ValueError("WebSocket URL is not a string") self._websocket = await websocket_connect(self._websocket_url) + + await self._call_event_handler("on_connected") except ValueError as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -291,6 +293,7 @@ class PlayHTTTSService(InterruptibleTTSService): finally: self._request_id = None self._websocket = None + await self._call_event_handler("on_disconnected") async def _get_websocket_url(self): """Retrieve WebSocket URL from PlayHT API.""" diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index 1ac829ebd..fa3fa447d 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -255,6 +255,8 @@ class RimeTTSService(AudioContextWordTTSService): url = f"{self._url}?{params}" headers = {"Authorization": f"Bearer {self._api_key}"} self._websocket = await websocket_connect(url, additional_headers=headers) + + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -272,6 +274,7 @@ class RimeTTSService(AudioContextWordTTSService): finally: self._context_id = None self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): """Get active websocket connection or raise exception.""" diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index a9fedcc58..75e6de125 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -525,6 +525,7 @@ class SarvamTTSService(InterruptibleTTSService): logger.debug("Connected to Sarvam TTS Websocket") await self._send_config() + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -556,6 +557,10 @@ class SarvamTTSService(InterruptibleTTSService): await self._websocket.close() except Exception as e: logger.error(f"{self} error closing websocket: {e}") + finally: + self._started = False + self._websocket = None + await self._call_event_handler("on_disconnected") def _get_websocket(self): if self._websocket: diff --git a/src/pipecat/services/speechmatics/stt.py b/src/pipecat/services/speechmatics/stt.py index 4028dd248..2c1db2a15 100644 --- a/src/pipecat/services/speechmatics/stt.py +++ b/src/pipecat/services/speechmatics/stt.py @@ -577,6 +577,7 @@ class SpeechmaticsSTTService(STTService): ), ) logger.debug(f"{self} Connected to Speechmatics STT service") + await self._call_event_handler("on_connected") except Exception as e: logger.error(f"{self} Error connecting to Speechmatics: {e}") self._client = None @@ -595,6 +596,7 @@ class SpeechmaticsSTTService(STTService): logger.error(f"{self} Error closing Speechmatics client: {e}") finally: self._client = None + await self._call_event_handler("on_disconnected") def _process_config(self) -> None: """Create a formatted STT transcription config. diff --git a/src/pipecat/services/stt_service.py b/src/pipecat/services/stt_service.py index a02619e44..6fb96c571 100644 --- a/src/pipecat/services/stt_service.py +++ b/src/pipecat/services/stt_service.py @@ -35,6 +35,25 @@ class STTService(AIService): Provides common functionality for STT services including audio passthrough, muting, settings management, and audio processing. Subclasses must implement the run_stt method to provide actual speech recognition. + + Event handlers: + on_connected: Called when connected to the STT service. + on_connected: Called when disconnected from the STT service. + on_connection_error: Called when a connection to the STT service error occurs. + + Example:: + + @stt.event_handler("on_connected") + async def on_connected(stt: STTService): + logger.debug(f"STT connected") + + @stt.event_handler("on_disconnected") + async def on_disconnected(stt: STTService): + logger.debug(f"STT disconnected") + + @stt.event_handler("on_connection_error") + async def on_connection_error(stt: STTService, error: str): + logger.error(f"STT connection error: {error}") """ def __init__( @@ -62,6 +81,10 @@ class STTService(AIService): self._muted: bool = False self._user_id: str = "" + self._register_event_handler("on_connected") + self._register_event_handler("on_disconnected") + self._register_event_handler("on_connection_error") + @property def is_muted(self) -> bool: """Check if the STT service is currently muted. @@ -292,15 +315,6 @@ class WebsocketSTTService(STTService, WebsocketService): Combines STT functionality with websocket connectivity, providing automatic error handling and reconnection capabilities. - - Event handlers: - on_connection_error: Called when a websocket connection error occurs. - - Example:: - - @stt.event_handler("on_connection_error") - async def on_connection_error(stt: STTService, error: str): - logger.error(f"STT connection error: {error}") """ def __init__(self, *, reconnect_on_error: bool = True, **kwargs): @@ -312,7 +326,6 @@ class WebsocketSTTService(STTService, WebsocketService): """ STTService.__init__(self, **kwargs) WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) - self._register_event_handler("on_connection_error") async def _report_error(self, error: ErrorFrame): await self._call_event_handler("on_connection_error", error.error) diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index a60b50818..b356c7244 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -59,6 +59,25 @@ class TTSService(AIService): Provides common functionality for TTS services including text aggregation, filtering, audio generation, and frame management. Supports configurable sentence aggregation, silence insertion, and frame processing control. + + Event handlers: + on_connected: Called when connected to the STT service. + on_connected: Called when disconnected from the STT service. + on_connection_error: Called when a connection to the STT service error occurs. + + Example:: + + @tts.event_handler("on_connected") + async def on_connected(tts: TTSService): + logger.debug(f"TTS connected") + + @tts.event_handler("on_disconnected") + async def on_disconnected(tts: TTSService): + logger.debug(f"TTS disconnected") + + @tts.event_handler("on_connection_error") + async def on_connection_error(stt: TTSService, error: str): + logger.error(f"TTS connection error: {error}") """ def __init__( @@ -143,6 +162,10 @@ class TTSService(AIService): self._processing_text: bool = False + self._register_event_handler("on_connected") + self._register_event_handler("on_disconnected") + self._register_event_handler("on_connection_error") + @property def sample_rate(self) -> int: """Get the current sample rate for audio output. @@ -626,7 +649,6 @@ class WebsocketTTSService(TTSService, WebsocketService): """ TTSService.__init__(self, **kwargs) WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) - self._register_event_handler("on_connection_error") async def _report_error(self, error: ErrorFrame): await self._call_event_handler("on_connection_error", error.error) @@ -678,15 +700,6 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService): """Base class for websocket-based TTS services that support word timestamps. Combines word timestamp functionality with websocket connectivity. - - Event handlers: - on_connection_error: Called when a websocket connection error occurs. - - Example:: - - @tts.event_handler("on_connection_error") - async def on_connection_error(tts: TTSService, error: str): - logger.error(f"TTS connection error: {error}") """ def __init__(self, *, reconnect_on_error: bool = True, **kwargs): @@ -698,7 +711,6 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService): """ WordTTSService.__init__(self, **kwargs) WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) - self._register_event_handler("on_connection_error") async def _report_error(self, error: ErrorFrame): await self._call_event_handler("on_connection_error", error.error)