services: added on_connected/on_disconnected events
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_connected` and `on_disconnected` events to TTS and STT
|
||||
websocket-based services.
|
||||
|
||||
- Added an `aggregate_sentences` arg in `ElevenLabsHttpTTSService`, where the
|
||||
default value is True.
|
||||
|
||||
|
||||
@@ -197,6 +197,8 @@ class AssemblyAISTTService(STTService):
|
||||
)
|
||||
self._connected = True
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AssemblyAI: {e}")
|
||||
self._connected = False
|
||||
@@ -238,6 +240,7 @@ class AssemblyAISTTService(STTService):
|
||||
self._websocket = None
|
||||
self._connected = False
|
||||
self._receive_task = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
|
||||
@@ -235,6 +235,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
}
|
||||
|
||||
await self._get_websocket().send(json.dumps(init_msg))
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -252,6 +254,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -286,6 +286,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
await self._disconnect()
|
||||
@@ -310,6 +311,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert internal language enum to AWS Transcribe language code.
|
||||
|
||||
@@ -273,6 +273,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to Cartesia: {e}")
|
||||
|
||||
@@ -285,6 +286,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -348,6 +348,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
self._websocket = await websocket_connect(
|
||||
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -365,6 +366,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -205,6 +205,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -225,6 +226,9 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_close_stream(self) -> None:
|
||||
"""Sends a CloseStream control message to the Deepgram Flux WebSocket API.
|
||||
|
||||
@@ -528,6 +528,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
url, max_size=16 * 1024 * 1024, additional_headers={"xi-api-key": self._api_key}
|
||||
)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -550,6 +551,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -225,6 +225,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
start_message = {"event": "start", "request": {"text": "", **self._settings}}
|
||||
await self._websocket.send(ormsgpack.packb(start_message))
|
||||
logger.debug("Sent start message to Fish Audio")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -245,6 +247,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any buffered audio by sending a flush event to Fish Audio."""
|
||||
|
||||
@@ -730,6 +730,8 @@ class GoogleSTTService(STTService):
|
||||
self._request_queue = asyncio.Queue()
|
||||
self._streaming_task = self.create_task(self._stream_audio())
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Clean up streaming recognition resources."""
|
||||
if self._streaming_task:
|
||||
@@ -737,6 +739,8 @@ class GoogleSTTService(STTService):
|
||||
await self.cancel_task(self._streaming_task)
|
||||
self._streaming_task = None
|
||||
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _request_generator(self):
|
||||
"""Generates requests for the streaming recognize method."""
|
||||
recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
|
||||
|
||||
@@ -222,6 +222,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# Send initialization message
|
||||
await self._websocket.send(json.dumps(init_msg))
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -243,6 +244,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get the WebSocket connection if available."""
|
||||
|
||||
@@ -293,6 +293,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
headers = {"x-api-key": self._api_key}
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -311,6 +313,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process messages from Neuphonic WebSocket."""
|
||||
|
||||
@@ -269,6 +269,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
raise ValueError("WebSocket URL is not a string")
|
||||
|
||||
self._websocket = await websocket_connect(self._websocket_url)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except ValueError as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -291,6 +293,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _get_websocket_url(self):
|
||||
"""Retrieve WebSocket URL from PlayHT API."""
|
||||
|
||||
@@ -255,6 +255,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
url = f"{self._url}?{params}"
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -272,6 +274,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
|
||||
@@ -525,6 +525,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
logger.debug("Connected to Sarvam TTS Websocket")
|
||||
await self._send_config()
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -556,6 +557,10 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -577,6 +577,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
),
|
||||
)
|
||||
logger.debug(f"{self} Connected to Speechmatics STT service")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error connecting to Speechmatics: {e}")
|
||||
self._client = None
|
||||
@@ -595,6 +596,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
logger.error(f"{self} Error closing Speechmatics client: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _process_config(self) -> None:
|
||||
"""Create a formatted STT transcription config.
|
||||
|
||||
@@ -35,6 +35,25 @@ class STTService(AIService):
|
||||
Provides common functionality for STT services including audio passthrough,
|
||||
muting, settings management, and audio processing. Subclasses must implement
|
||||
the run_stt method to provide actual speech recognition.
|
||||
|
||||
Event handlers:
|
||||
on_connected: Called when connected to the STT service.
|
||||
on_disconnected: Called when disconnected from the STT service.
|
||||
on_connection_error: Called when an STT service connection error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_connected")
|
||||
async def on_connected(stt: STTService):
|
||||
logger.debug(f"STT connected")
|
||||
|
||||
@stt.event_handler("on_disconnected")
|
||||
async def on_disconnected(stt: STTService):
|
||||
logger.debug(f"STT disconnected")
|
||||
|
||||
@stt.event_handler("on_connection_error")
|
||||
async def on_connection_error(stt: STTService, error: str):
|
||||
logger.error(f"STT connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -62,6 +81,10 @@ class STTService(AIService):
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
@property
|
||||
def is_muted(self) -> bool:
|
||||
"""Check if the STT service is currently muted.
|
||||
@@ -292,15 +315,6 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
|
||||
Combines STT functionality with websocket connectivity, providing automatic
|
||||
error handling and reconnection capabilities.
|
||||
|
||||
Event handlers:
|
||||
on_connection_error: Called when a websocket connection error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_connection_error")
|
||||
async def on_connection_error(stt: STTService, error: str):
|
||||
logger.error(f"STT connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
@@ -312,7 +326,6 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
"""
|
||||
STTService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
|
||||
@@ -59,6 +59,25 @@ class TTSService(AIService):
|
||||
Provides common functionality for TTS services including text aggregation,
|
||||
filtering, audio generation, and frame management. Supports configurable
|
||||
sentence aggregation, silence insertion, and frame processing control.
|
||||
|
||||
Event handlers:
|
||||
on_connected: Called when connected to the TTS service.
|
||||
on_disconnected: Called when disconnected from the TTS service.
|
||||
on_connection_error: Called when a TTS service connection error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@tts.event_handler("on_connected")
|
||||
async def on_connected(tts: TTSService):
|
||||
logger.debug(f"TTS connected")
|
||||
|
||||
@tts.event_handler("on_disconnected")
|
||||
async def on_disconnected(tts: TTSService):
|
||||
logger.debug(f"TTS disconnected")
|
||||
|
||||
@tts.event_handler("on_connection_error")
|
||||
async def on_connection_error(tts: TTSService, error: str):
|
||||
logger.error(f"TTS connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -143,6 +162,10 @@ class TTSService(AIService):
|
||||
|
||||
self._processing_text: bool = False
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate for audio output.
|
||||
@@ -626,7 +649,6 @@ class WebsocketTTSService(TTSService, WebsocketService):
|
||||
"""
|
||||
TTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
@@ -678,15 +700,6 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
|
||||
"""Base class for websocket-based TTS services that support word timestamps.
|
||||
|
||||
Combines word timestamp functionality with websocket connectivity.
|
||||
|
||||
Event handlers:
|
||||
on_connection_error: Called when a websocket connection error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@tts.event_handler("on_connection_error")
|
||||
async def on_connection_error(tts: TTSService, error: str):
|
||||
logger.error(f"TTS connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
@@ -698,7 +711,6 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
|
||||
"""
|
||||
WordTTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
|
||||
Reference in New Issue
Block a user