Enhance GladiaSTTService with reconnection and audio buffer management features

- Added parameters for maximum reconnection attempts, reconnection delay, and maximum audio buffer size.
- Implemented automatic reconnection logic with exponential backoff.
- Introduced audio buffer management to handle audio data efficiently, including trimming excess data.
- Updated connection handling to ensure proper cleanup and management of WebSocket connections.
- Enhanced audio sending logic to support buffered audio transmission after reconnections.
This commit is contained in:
jqueguiner
2025-06-03 03:16:57 -07:00
parent a8aaeec52b
commit 02cc6f3d56

View File

@@ -195,6 +195,9 @@ class GladiaSTTService(STTService):
sample_rate: Optional[int] = None,
model: str = "solaria-1",
params: Optional[GladiaInputParams] = None,
max_reconnection_attempts: int = 5,
reconnection_delay: float = 1.0,
max_buffer_size: int = 1024 * 1024 * 5, # 5MB default buffer
**kwargs,
):
"""Initialize the Gladia STT service.
@@ -207,6 +210,9 @@ class GladiaSTTService(STTService):
model: Model to use ("solaria-1", "solaria-mini-1", "fast",
or "accurate")
params: Additional configuration parameters
max_reconnection_attempts: Maximum number of reconnection attempts
reconnection_delay: Initial delay between reconnection attempts (exponential backoff)
max_buffer_size: Maximum size of audio buffer in bytes
**kwargs: Additional arguments passed to the STTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -232,6 +238,23 @@ class GladiaSTTService(STTService):
self._keepalive_task = None
self._settings = {}
# Reconnection settings
self._max_reconnection_attempts = max_reconnection_attempts
self._reconnection_delay = reconnection_delay
self._reconnection_attempts = 0
self._session_url = None
self._connection_active = False
# Audio buffer management
self._audio_buffer = bytearray()
self._bytes_sent = 0
self._max_buffer_size = max_buffer_size
self._buffer_lock = asyncio.Lock()
# Connection management
self._connection_task = None
self._should_reconnect = True
def can_generate_metrics(self) -> bool:
return True
@@ -293,36 +316,149 @@ class GladiaSTTService(STTService):
async def start(self, frame: StartFrame):
"""Start the Gladia STT websocket connection."""
await super().start(frame)
if self._websocket:
if self._connection_task:
return
settings = self._prepare_settings()
response = await self._setup_gladia(settings)
self._websocket = await websockets.connect(response["url"])
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler())
if self._websocket and not self._keepalive_task:
self._keepalive_task = self.create_task(self._keepalive_task_handler())
self._should_reconnect = True
self._connection_task = self.create_task(self._connection_handler())
async def stop(self, frame: EndFrame):
"""Stop the Gladia STT websocket connection."""
await super().stop(frame)
self._should_reconnect = False
await self._send_stop_recording()
if self._keepalive_task:
await self.cancel_task(self._keepalive_task)
self._keepalive_task = None
if self._connection_task:
await self.cancel_task(self._connection_task)
self._connection_task = None
if self._websocket:
await self._websocket.close()
self._websocket = None
if self._receive_task:
await self.wait_for_task(self._receive_task)
self._receive_task = None
await self._cleanup_connection()
async def cancel(self, frame: CancelFrame):
"""Cancel the Gladia STT websocket connection."""
await super().cancel(frame)
self._should_reconnect = False
if self._connection_task:
await self.cancel_task(self._connection_task)
self._connection_task = None
await self._cleanup_connection()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Run speech-to-text on audio data."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
# Add audio to buffer
async with self._buffer_lock:
self._audio_buffer.extend(audio)
# Trim buffer if it exceeds max size
if len(self._audio_buffer) > self._max_buffer_size:
trim_size = len(self._audio_buffer) - self._max_buffer_size
self._audio_buffer = self._audio_buffer[trim_size:]
self._bytes_sent = max(0, self._bytes_sent - trim_size)
logger.warning(f"Audio buffer exceeded max size, trimmed {trim_size} bytes")
# Send audio if connected
if self._connection_active and self._websocket and not self._websocket.closed:
await self._send_audio(audio)
yield None
async def _connection_handler(self):
"""Handle WebSocket connection with automatic reconnection."""
while self._should_reconnect:
try:
# Initialize session if needed
if not self._session_url:
settings = self._prepare_settings()
response = await self._setup_gladia(settings)
self._session_url = response["url"]
self._reconnection_attempts = 0
# Connect with automatic reconnection
async for websocket in websockets.connect(self._session_url):
try:
self._websocket = websocket
self._connection_active = True
logger.info("Connected to Gladia WebSocket")
# Send buffered audio if any
await self._send_buffered_audio()
# Start tasks
receive_task = asyncio.create_task(self._receive_task_handler())
keepalive_task = asyncio.create_task(self._keepalive_task_handler())
# Wait for tasks to complete
await asyncio.gather(receive_task, keepalive_task)
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"WebSocket connection closed: {e}")
self._connection_active = False
# Clean up tasks
if "receive_task" in locals():
receive_task.cancel()
if "keepalive_task" in locals():
keepalive_task.cancel()
# Check if we should reconnect
if not self._should_reconnect:
break
# Implement exponential backoff
self._reconnection_attempts += 1
if self._reconnection_attempts > self._max_reconnection_attempts:
logger.error(
f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
)
self._should_reconnect = False
break
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
logger.info(
f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
)
await asyncio.sleep(delay)
except Exception as e:
logger.error(f"Error in WebSocket connection: {e}")
self._connection_active = False
# Same reconnection logic as above
if not self._should_reconnect:
break
self._reconnection_attempts += 1
if self._reconnection_attempts > self._max_reconnection_attempts:
logger.error(
f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
)
self._should_reconnect = False
break
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
logger.info(
f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
)
await asyncio.sleep(delay)
except Exception as e:
logger.error(f"Error in connection handler: {e}")
self._connection_active = False
if not self._should_reconnect:
break
# Reset session URL to get a new one
self._session_url = None
await asyncio.sleep(self._reconnection_delay)
async def _cleanup_connection(self):
"""Clean up connection resources."""
self._connection_active = False
if self._keepalive_task:
await self.cancel_task(self._keepalive_task)
@@ -336,13 +472,6 @@ class GladiaSTTService(STTService):
await self.cancel_task(self._receive_task)
self._receive_task = None
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Run speech-to-text on audio data."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._send_audio(audio)
yield None
async def _setup_gladia(self, settings: Dict[str, Any]):
async with aiohttp.ClientSession() as session:
async with session.post(
@@ -369,9 +498,25 @@ class GladiaSTTService(STTService):
await self.stop_processing_metrics()
async def _send_audio(self, audio: bytes):
data = base64.b64encode(audio).decode("utf-8")
message = {"type": "audio_chunk", "data": {"chunk": data}}
await self._websocket.send(json.dumps(message))
"""Send audio chunk with proper message format."""
if self._websocket and not self._websocket.closed:
data = base64.b64encode(audio).decode("utf-8")
message = {"type": "audio_chunk", "data": {"chunk": data}}
await self._websocket.send(json.dumps(message))
async def _send_buffered_audio(self):
"""Send any buffered audio after reconnection."""
async with self._buffer_lock:
if self._bytes_sent < len(self._audio_buffer):
buffered_data = self._audio_buffer[self._bytes_sent :]
if buffered_data:
logger.info(f"Sending {len(buffered_data)} bytes of buffered audio")
# Send in chunks to avoid overwhelming the connection
chunk_size = 16384 # 16KB chunks
for i in range(0, len(buffered_data), chunk_size):
chunk = buffered_data[i : i + chunk_size]
await self._send_audio(bytes(chunk))
await asyncio.sleep(0.01) # Small delay between chunks
async def _send_stop_recording(self):
if self._websocket and not self._websocket.closed:
@@ -380,7 +525,7 @@ class GladiaSTTService(STTService):
async def _keepalive_task_handler(self):
"""Send periodic empty audio chunks to keep the connection alive."""
try:
while True:
while self._connection_active:
# Send keepalive every 20 seconds (Gladia times out after 30 seconds)
await asyncio.sleep(20)
if self._websocket and not self._websocket.closed:
@@ -399,7 +544,19 @@ class GladiaSTTService(STTService):
try:
async for message in self._websocket:
content = json.loads(message)
if content["type"] == "transcript":
# Handle audio chunk acknowledgments
if content["type"] == "audio_chunk" and content.get("acknowledged"):
byte_range = content["data"]["byte_range"]
async with self._buffer_lock:
# Update bytes sent and trim acknowledged data from buffer
end_byte = byte_range[1]
if end_byte > self._bytes_sent:
trim_size = end_byte - self._bytes_sent
self._audio_buffer = self._audio_buffer[trim_size:]
self._bytes_sent = end_byte
elif content["type"] == "transcript":
utterance = content["data"]["utterance"]
confidence = utterance.get("confidence", 0)
language = utterance["language"]