Enhance GladiaSTTService with reconnection and audio buffer management features
- Added parameters for maximum reconnection attempts, reconnection delay, and maximum audio buffer size. - Implemented automatic reconnection logic with exponential backoff. - Introduced audio buffer management to handle audio data efficiently, including trimming excess data. - Updated connection handling to ensure proper cleanup and management of WebSocket connections. - Enhanced audio sending logic to support buffered audio transmission after reconnections.
This commit is contained in:
@@ -195,6 +195,9 @@ class GladiaSTTService(STTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
max_reconnection_attempts: int = 5,
|
||||
reconnection_delay: float = 1.0,
|
||||
max_buffer_size: int = 1024 * 1024 * 5, # 5MB default buffer
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gladia STT service.
|
||||
@@ -207,6 +210,9 @@ class GladiaSTTService(STTService):
|
||||
model: Model to use ("solaria-1", "solaria-mini-1", "fast",
|
||||
or "accurate")
|
||||
params: Additional configuration parameters
|
||||
max_reconnection_attempts: Maximum number of reconnection attempts
|
||||
reconnection_delay: Initial delay between reconnection attempts (exponential backoff)
|
||||
max_buffer_size: Maximum size of audio buffer in bytes
|
||||
**kwargs: Additional arguments passed to the STTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -232,6 +238,23 @@ class GladiaSTTService(STTService):
|
||||
self._keepalive_task = None
|
||||
self._settings = {}
|
||||
|
||||
# Reconnection settings
|
||||
self._max_reconnection_attempts = max_reconnection_attempts
|
||||
self._reconnection_delay = reconnection_delay
|
||||
self._reconnection_attempts = 0
|
||||
self._session_url = None
|
||||
self._connection_active = False
|
||||
|
||||
# Audio buffer management
|
||||
self._audio_buffer = bytearray()
|
||||
self._bytes_sent = 0
|
||||
self._max_buffer_size = max_buffer_size
|
||||
self._buffer_lock = asyncio.Lock()
|
||||
|
||||
# Connection management
|
||||
self._connection_task = None
|
||||
self._should_reconnect = True
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -293,36 +316,149 @@ class GladiaSTTService(STTService):
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Gladia STT websocket connection."""
|
||||
await super().start(frame)
|
||||
if self._websocket:
|
||||
if self._connection_task:
|
||||
return
|
||||
settings = self._prepare_settings()
|
||||
response = await self._setup_gladia(settings)
|
||||
self._websocket = await websockets.connect(response["url"])
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
if self._websocket and not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
self._should_reconnect = True
|
||||
self._connection_task = self.create_task(self._connection_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Gladia STT websocket connection."""
|
||||
await super().stop(frame)
|
||||
self._should_reconnect = False
|
||||
await self._send_stop_recording()
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
if self._connection_task:
|
||||
await self.cancel_task(self._connection_task)
|
||||
self._connection_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
await self.wait_for_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
await self._cleanup_connection()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Gladia STT websocket connection."""
|
||||
await super().cancel(frame)
|
||||
self._should_reconnect = False
|
||||
|
||||
if self._connection_task:
|
||||
await self.cancel_task(self._connection_task)
|
||||
self._connection_task = None
|
||||
|
||||
await self._cleanup_connection()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Run speech-to-text on audio data."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
# Add audio to buffer
|
||||
async with self._buffer_lock:
|
||||
self._audio_buffer.extend(audio)
|
||||
# Trim buffer if it exceeds max size
|
||||
if len(self._audio_buffer) > self._max_buffer_size:
|
||||
trim_size = len(self._audio_buffer) - self._max_buffer_size
|
||||
self._audio_buffer = self._audio_buffer[trim_size:]
|
||||
self._bytes_sent = max(0, self._bytes_sent - trim_size)
|
||||
logger.warning(f"Audio buffer exceeded max size, trimmed {trim_size} bytes")
|
||||
|
||||
# Send audio if connected
|
||||
if self._connection_active and self._websocket and not self._websocket.closed:
|
||||
await self._send_audio(audio)
|
||||
|
||||
yield None
|
||||
|
||||
async def _connection_handler(self):
|
||||
"""Handle WebSocket connection with automatic reconnection."""
|
||||
while self._should_reconnect:
|
||||
try:
|
||||
# Initialize session if needed
|
||||
if not self._session_url:
|
||||
settings = self._prepare_settings()
|
||||
response = await self._setup_gladia(settings)
|
||||
self._session_url = response["url"]
|
||||
self._reconnection_attempts = 0
|
||||
|
||||
# Connect with automatic reconnection
|
||||
async for websocket in websockets.connect(self._session_url):
|
||||
try:
|
||||
self._websocket = websocket
|
||||
self._connection_active = True
|
||||
logger.info("Connected to Gladia WebSocket")
|
||||
|
||||
# Send buffered audio if any
|
||||
await self._send_buffered_audio()
|
||||
|
||||
# Start tasks
|
||||
receive_task = asyncio.create_task(self._receive_task_handler())
|
||||
keepalive_task = asyncio.create_task(self._keepalive_task_handler())
|
||||
|
||||
# Wait for tasks to complete
|
||||
await asyncio.gather(receive_task, keepalive_task)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(f"WebSocket connection closed: {e}")
|
||||
self._connection_active = False
|
||||
|
||||
# Clean up tasks
|
||||
if "receive_task" in locals():
|
||||
receive_task.cancel()
|
||||
if "keepalive_task" in locals():
|
||||
keepalive_task.cancel()
|
||||
|
||||
# Check if we should reconnect
|
||||
if not self._should_reconnect:
|
||||
break
|
||||
|
||||
# Implement exponential backoff
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
logger.error(
|
||||
f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
|
||||
)
|
||||
self._should_reconnect = False
|
||||
break
|
||||
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
logger.info(
|
||||
f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket connection: {e}")
|
||||
self._connection_active = False
|
||||
|
||||
# Same reconnection logic as above
|
||||
if not self._should_reconnect:
|
||||
break
|
||||
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
logger.error(
|
||||
f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
|
||||
)
|
||||
self._should_reconnect = False
|
||||
break
|
||||
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
logger.info(
|
||||
f"Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection handler: {e}")
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
break
|
||||
|
||||
# Reset session URL to get a new one
|
||||
self._session_url = None
|
||||
await asyncio.sleep(self._reconnection_delay)
|
||||
|
||||
async def _cleanup_connection(self):
|
||||
"""Clean up connection resources."""
|
||||
self._connection_active = False
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
@@ -336,13 +472,6 @@ class GladiaSTTService(STTService):
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Run speech-to-text on audio data."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._send_audio(audio)
|
||||
yield None
|
||||
|
||||
async def _setup_gladia(self, settings: Dict[str, Any]):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
@@ -369,9 +498,25 @@ class GladiaSTTService(STTService):
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
async def _send_audio(self, audio: bytes):
|
||||
data = base64.b64encode(audio).decode("utf-8")
|
||||
message = {"type": "audio_chunk", "data": {"chunk": data}}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
"""Send audio chunk with proper message format."""
|
||||
if self._websocket and not self._websocket.closed:
|
||||
data = base64.b64encode(audio).decode("utf-8")
|
||||
message = {"type": "audio_chunk", "data": {"chunk": data}}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
|
||||
async def _send_buffered_audio(self):
|
||||
"""Send any buffered audio after reconnection."""
|
||||
async with self._buffer_lock:
|
||||
if self._bytes_sent < len(self._audio_buffer):
|
||||
buffered_data = self._audio_buffer[self._bytes_sent :]
|
||||
if buffered_data:
|
||||
logger.info(f"Sending {len(buffered_data)} bytes of buffered audio")
|
||||
# Send in chunks to avoid overwhelming the connection
|
||||
chunk_size = 16384 # 16KB chunks
|
||||
for i in range(0, len(buffered_data), chunk_size):
|
||||
chunk = buffered_data[i : i + chunk_size]
|
||||
await self._send_audio(bytes(chunk))
|
||||
await asyncio.sleep(0.01) # Small delay between chunks
|
||||
|
||||
async def _send_stop_recording(self):
|
||||
if self._websocket and not self._websocket.closed:
|
||||
@@ -380,7 +525,7 @@ class GladiaSTTService(STTService):
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic empty audio chunks to keep the connection alive."""
|
||||
try:
|
||||
while True:
|
||||
while self._connection_active:
|
||||
# Send keepalive every 20 seconds (Gladia times out after 30 seconds)
|
||||
await asyncio.sleep(20)
|
||||
if self._websocket and not self._websocket.closed:
|
||||
@@ -399,7 +544,19 @@ class GladiaSTTService(STTService):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
content = json.loads(message)
|
||||
if content["type"] == "transcript":
|
||||
|
||||
# Handle audio chunk acknowledgments
|
||||
if content["type"] == "audio_chunk" and content.get("acknowledged"):
|
||||
byte_range = content["data"]["byte_range"]
|
||||
async with self._buffer_lock:
|
||||
# Update bytes sent and trim acknowledged data from buffer
|
||||
end_byte = byte_range[1]
|
||||
if end_byte > self._bytes_sent:
|
||||
trim_size = end_byte - self._bytes_sent
|
||||
self._audio_buffer = self._audio_buffer[trim_size:]
|
||||
self._bytes_sent = end_byte
|
||||
|
||||
elif content["type"] == "transcript":
|
||||
utterance = content["data"]["utterance"]
|
||||
confidence = utterance.get("confidence", 0)
|
||||
language = utterance["language"]
|
||||
|
||||
Reference in New Issue
Block a user