Refactoring the services to use push_error and push_error_frame

This commit is contained in:
Filipi Fuchter
2025-11-18 18:43:30 -03:00
parent 50bef86d33
commit fdf3c8b4cf
27 changed files with 48 additions and 94 deletions

View File

@@ -240,8 +240,7 @@ class AssemblyAISTTService(STTService):
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
self._websocket = None
@@ -293,8 +292,7 @@ class AssemblyAISTTService(STTService):
elif isinstance(parsed_message, TerminationMessage):
await self._handle_termination(parsed_message)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
async def _handle_termination(self, message: TerminationMessage):
"""Handle termination message."""

View File

@@ -240,8 +240,7 @@ class AsyncAITTSService(InterruptibleTTSService):
logger.debug("Disconnecting from Async")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
self._websocket = None
self._started = False
@@ -476,8 +475,7 @@ class AsyncAIHttpTTSService(TTSService):
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Async API error: {error_text}")
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
await self.push_error(error_msg=f"Async API error: {error_text}")
raise Exception(f"Async API returned status {response.status}: {error_text}")
audio_data = await response.read()

View File

@@ -288,8 +288,7 @@ class AWSTranscribeSTTService(STTService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
await self._disconnect()
raise
@@ -536,6 +535,5 @@ class AWSTranscribeSTTService(STTService):
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
break

View File

@@ -284,8 +284,7 @@ class CartesiaSTTService(WebsocketSTTService):
logger.debug("Disconnecting from Cartesia STT")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"{self} error closing websocket: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")

View File

@@ -409,8 +409,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
logger.debug("Disconnecting from Cartesia")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
self._context_id = None
self._websocket = None

View File

@@ -250,8 +250,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
logger.debug("Connected to Deepgram Flux Websocket")
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")

View File

@@ -424,8 +424,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
json.dumps({"context_id": self._context_id, "close_context": True})
)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._context_id = None
self._started = False
@@ -553,8 +552,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._websocket.close()
logger.debug("Disconnected from ElevenLabs")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
self._started = False
self._context_id = None

View File

@@ -284,8 +284,7 @@ class FishAudioTTSService(InterruptibleTTSService):
continue
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:

View File

@@ -468,8 +468,7 @@ class GladiaSTTService(STTService):
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._connection_active = False
if not self._should_reconnect:
@@ -623,8 +622,7 @@ class GladiaSTTService(STTService):
# Expected when closing the connection
pass
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
async def _maybe_reconnect(self) -> bool:
"""Handle exponential backoff reconnection logic."""

View File

@@ -1174,7 +1174,7 @@ class GeminiLiveLLMService(LLMService):
self._connection_task = self.create_task(self._connection_task_handler(config=config))
except Exception as e:
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}"))
await self.push_error(exception=e)
async def _connection_task_handler(self, config: LiveConnectConfig):
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:

View File

@@ -774,8 +774,7 @@ class GoogleSTTService(STTService):
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
raise
async def _stream_audio(self):
@@ -813,8 +812,7 @@ class GoogleSTTService(STTService):
self._stream_start_time = int(time.time() * 1000)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process an audio chunk for STT transcription.

View File

@@ -216,8 +216,7 @@ class HumeTTSService(TTSService):
self._audio_bytes = b""
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
# Ensure TTFB timer is stopped even on early failures
await self.stop_ttfb_metrics()

View File

@@ -392,8 +392,7 @@ class InworldTTSService(TTSService):
# STEP 7: ERROR HANDLING
# ================================================================================
# Log any unexpected errors and notify the pipeline
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
# ================================================================================
# STEP 8: CLEANUP AND COMPLETION

View File

@@ -214,8 +214,7 @@ class LmntTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -266,10 +265,9 @@ class LmntTTSService(InterruptibleTTSService):
try:
msg = json.loads(message)
if "error" in msg:
logger.error(f"{self} error: {msg['error']}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
await self.push_error(error_msg=f"{self} error: {msg['error']}")
return
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")

View File

@@ -285,8 +285,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")

View File

@@ -478,7 +478,7 @@ class OpenAIRealtimeLLMService(LLMService):
# it is to recover from a send-side error with proper state management, and that exponential
# backoff for retries can have cost/stability implications for a service cluster, let's just
# treat a send-side error as fatal.
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
async def _update_settings(self):
settings = self._session_properties
@@ -759,7 +759,7 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_error(self, evt):
# Errors are fatal to this connection. Send an ErrorFrame.
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
await self.push_error(error_msg=f"Error: {evt}")
#
# state and client events for the current conversation

View File

@@ -454,7 +454,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
# it is to recover from a send-side error with proper state management, and that exponential
# backoff for retries can have cost/stability implications for a service cluster, let's just
# treat a send-side error as fatal.
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
async def _update_settings(self):
settings = self._session_properties
@@ -685,7 +685,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_error(self, evt):
# Errors are fatal to this connection. Send an ErrorFrame.
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
await self.push_error(error_msg=f"Error: {evt}")
async def _handle_assistant_output(self, output):
# We haven't seen intermixed audio and function_call items in the same response. But let's

View File

@@ -266,8 +266,7 @@ class PlayHTTTSService(InterruptibleTTSService):
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -352,7 +351,7 @@ class PlayHTTTSService(InterruptibleTTSService):
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
await self.push_error(error_msg=f"{self} error: {msg['error']}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")

View File

@@ -259,8 +259,7 @@ class RimeTTSService(AudioContextWordTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -366,10 +365,9 @@ class RimeTTSService(AudioContextWordTTSService):
logger.debug(f"Updated cumulative time to: {self._cumulative_time}")
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
await self.push_error(error_msg=f"{self} error: {msg['message']}")
self._context_id = None
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):

View File

@@ -275,8 +275,7 @@ class SarvamSTTService(STTService):
await self._socket_client.translate(**method_kwargs)
except Exception as e:
logger.error(f"Error sending audio to Sarvam: {e}")
await self.push_error(ErrorFrame(f"Failed to send audio: {e}"))
await self.push_error(error_msg=f"Error sending audio to Sarvam: {e}", exception=e)
yield None
@@ -332,13 +331,11 @@ class SarvamSTTService(STTService):
logger.info("Connected to Sarvam successfully")
except ApiError as e:
logger.error(f"Sarvam API error: {e}")
await self.push_error(ErrorFrame(f"Sarvam API error: {e}"))
await self.push_error(error_msg=f"Sarvam API error: {e}", exception=e)
except Exception as e:
logger.error(f"Failed to connect to Sarvam: {e}")
self._socket_client = None
self._websocket_context = None
await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}"))
await self.push_error(error_msg=f"Failed to connect to Sarvam: {e}", exception=e)
async def _disconnect(self):
"""Disconnect from Sarvam WebSocket API using SDK."""
@@ -427,8 +424,7 @@ class SarvamSTTService(STTService):
await self.stop_processing_metrics()
except Exception as e:
logger.error(f"Error handling Sarvam message: {e}")
await self.push_error(ErrorFrame(f"Failed to handle message: {e}"))
await self.push_error(error_msg=f"Failed to handle message: {e}", exception=e)
await self.stop_all_metrics()
@traced_stt

View File

@@ -254,8 +254,7 @@ class SarvamHttpTTSService(TTSService):
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Sarvam API error: {error_text}")
await self.push_error(ErrorFrame(error=f"Sarvam API error: {error_text}"))
await self.push_error(error_msg=f"Sarvam API error: {error_text}")
return
response_data = await response.json()
@@ -264,8 +263,7 @@ class SarvamHttpTTSService(TTSService):
# Decode base64 audio data
if "audios" not in response_data or not response_data["audios"]:
logger.error("No audio data received from Sarvam API")
await self.push_error(ErrorFrame(error="No audio data received"))
await self.push_error(error_msg="No audio data received")
return
# Get the first audio (there should be only one for single text input)
@@ -560,8 +558,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self._disconnect_websocket()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
# Reset state only after everything is cleaned up
self._started = False
@@ -602,8 +599,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self._websocket.send(json.dumps(config_message))
logger.debug("Configuration sent successfully")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
raise
async def _disconnect_websocket(self):

View File

@@ -327,8 +327,7 @@ class SonioxSTTService(STTService):
# Expected when closing the connection
logger.debug("WebSocket connection closed, keepalive task stopped.")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
async def _receive_task_handler(self):
if not self._websocket:
@@ -404,14 +403,7 @@ class SonioxSTTService(STTService):
if error_code or error_message:
# In case of error, still send the final transcript (if any remaining in the buffer).
await send_endpoint_transcript()
logger.error(
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
)
await self.push_error(
ErrorFrame(
error=f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
)
)
await self.push_error(error_msg=f"{self} error: {error_code} (_receive_task_handler) - {error_message}")
finished = content.get("finished")
if finished:

View File

@@ -514,8 +514,7 @@ class SpeechmaticsSTTService(STTService):
self._client.send_message(payload), self.get_event_loop()
)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
raise RuntimeError(f"error sending message to STT: {e}")
async def _connect(self) -> None:
@@ -596,8 +595,7 @@ class SpeechmaticsSTTService(STTService):
except asyncio.TimeoutError:
logger.warning(f"{self} Timeout while closing Speechmatics client connection")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
finally:
self._client = None
await self._call_event_handler("on_disconnected")

View File

@@ -329,4 +329,4 @@ class WebsocketSTTService(STTService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)

View File

@@ -671,7 +671,7 @@ class WebsocketTTSService(TTSService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)
class InterruptibleTTSService(WebsocketTTSService):
@@ -733,7 +733,7 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)
class InterruptibleWordTTSService(WebsocketWordTTSService):

View File

@@ -246,8 +246,7 @@ class UltravoxSTTService(AIService):
logger.info("Model warm-up completed successfully")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(exception=e)
def _generate_silent_audio(self, sample_rate=16000, duration_sec=1.0):
"""Generate silent audio as a numpy array.

View File

@@ -2506,13 +2506,10 @@ class DailyTransport(BaseTransport):
async def _on_error(self, error):
"""Handle error events and push error frames."""
await self._call_event_handler("on_error", error)
# Push error frame to notify the pipeline
error_frame = ErrorFrame(error)
if self._input:
await self._input.push_error(error_frame)
await self._input.push_error(error_msg=error)
elif self._output:
await self._output.push_error(error_frame)
await self._input.push_error(error_msg=error)
else:
logger.error("Both input and output are None while trying to push error")
raise Exception("No valid input or output channel to push error")