Add reconnect logic to handle Google's 5 min time limit
This commit is contained in:
@@ -9,6 +9,7 @@ import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
@@ -1421,6 +1422,12 @@ class GoogleSTTService(STTService):
|
||||
InputParams: Configuration parameters for the STT service.
|
||||
"""
|
||||
|
||||
# Google Cloud's STT service has a connection time limit of 5 minutes per stream.
|
||||
# They've shared an "endless streaming" example that guided this implementation:
|
||||
# https://cloud.google.com/speech-to-text/docs/transcribe-streaming-audio#endless-streaming
|
||||
|
||||
STREAMING_LIMIT = 240000 # 4 minutes in milliseconds
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Google Speech-to-Text.
|
||||
|
||||
@@ -1495,6 +1502,18 @@ class GoogleSTTService(STTService):
|
||||
self._request_queue = asyncio.Queue()
|
||||
self._streaming_task = None
|
||||
|
||||
# State used by the stream reconnect logic
|
||||
self._stream_start_time = 0
|
||||
self._last_audio_input = []
|
||||
self._audio_input = []
|
||||
self._result_end_time = 0
|
||||
self._is_final_end_time = 0
|
||||
self._final_request_end_time = 0
|
||||
self._bridging_offset = 0
|
||||
self._last_transcript_was_final = False
|
||||
self._new_stream = True
|
||||
self._restart_counter = 0
|
||||
|
||||
# Configure client options based on location
|
||||
client_options = None
|
||||
if self._location != "global":
|
||||
@@ -1687,7 +1706,10 @@ class GoogleSTTService(STTService):
|
||||
"""Initialize streaming recognition config and stream."""
|
||||
logger.debug("Connecting to Google Speech-to-Text")
|
||||
|
||||
# Create recognition config with explicit audio format
|
||||
# Record when this stream started (epoch milliseconds) so it can be checked against STREAMING_LIMIT
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
self._new_stream = True
|
||||
|
||||
self._config = cloud_speech.StreamingRecognitionConfig(
|
||||
config=cloud_speech.RecognitionConfig(
|
||||
explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
|
||||
@@ -1736,24 +1758,37 @@ class GoogleSTTService(STTService):
|
||||
logger.trace(f"Using recognizer path: {recognizer_path}")
|
||||
|
||||
try:
|
||||
# First, send the recognition config
|
||||
config_request = cloud_speech.StreamingRecognizeRequest(
|
||||
# Send initial config
|
||||
yield cloud_speech.StreamingRecognizeRequest(
|
||||
recognizer=recognizer_path,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
yield config_request
|
||||
|
||||
# Then send all audio data requests
|
||||
while True:
|
||||
try:
|
||||
audio_data = await self._request_queue.get()
|
||||
if audio_data is None: # Sentinel value to stop
|
||||
break
|
||||
|
||||
# Check streaming limit
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Streaming limit reached, initiating graceful reconnection")
|
||||
# Instead of immediate reconnection, we'll break and let the stream close naturally
|
||||
self._last_audio_input = self._audio_input
|
||||
self._audio_input = []
|
||||
self._restart_counter += 1
|
||||
# Put the current audio chunk back in the queue
|
||||
await self._request_queue.put(audio_data)
|
||||
break
|
||||
|
||||
self._audio_input.append(audio_data)
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
self._request_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in request generator: {e}")
|
||||
raise
|
||||
@@ -1761,16 +1796,31 @@ class GoogleSTTService(STTService):
|
||||
async def _stream_audio(self):
|
||||
"""Handle bi-directional streaming with Google STT."""
|
||||
try:
|
||||
# Start bi-directional streaming
|
||||
streaming_recognize = await self._client.streaming_recognize(
|
||||
requests=self._request_generator()
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
# Start bi-directional streaming
|
||||
streaming_recognize = await self._client.streaming_recognize(
|
||||
requests=self._request_generator()
|
||||
)
|
||||
|
||||
# Process responses using task manager
|
||||
response_task = self.create_task(self._process_responses(streaming_recognize))
|
||||
# Process responses
|
||||
await self._process_responses(streaming_recognize)
|
||||
|
||||
# Wait for the response processing to complete
|
||||
await self.wait_for_task(response_task)
|
||||
# If we're here, check if we need to reconnect
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Reconnecting stream after timeout")
|
||||
# Reset stream start time
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
else:
|
||||
# Normal stream end
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error, attempting to reconnect: {e}")
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming task: {e}")
|
||||
@@ -1787,6 +1837,11 @@ class GoogleSTTService(STTService):
|
||||
"""Process streaming recognition responses."""
|
||||
try:
|
||||
async for response in streaming_recognize:
|
||||
# Check streaming limit
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Stream timeout reached in response processing")
|
||||
break
|
||||
|
||||
if not response.results:
|
||||
continue
|
||||
|
||||
@@ -1798,14 +1853,15 @@ class GoogleSTTService(STTService):
|
||||
if not transcript:
|
||||
continue
|
||||
|
||||
# Use the primary language (first in the list)
|
||||
primary_language = self._settings["language_codes"][0]
|
||||
|
||||
if result.is_final:
|
||||
self._last_transcript_was_final = True
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
|
||||
)
|
||||
else:
|
||||
self._last_transcript_was_final = False
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), primary_language
|
||||
@@ -1814,4 +1870,3 @@ class GoogleSTTService(STTService):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google STT responses: {e}")
|
||||
await self.push_frame(ErrorFrame(str(e)))
|
||||
|
||||
Reference in New Issue
Block a user