Widen run_stt/run_tts return type to AsyncGenerator[Frame | None, None]

The push-based STT/TTS implementations send audio/text over a socket and
receive results via a separate receive task, so they have nothing to
yield inline and instead yield `None` by design. The previous return
annotation, `AsyncGenerator[Frame, None]`, did not permit that, even
though the consumer (`AIService.process_generator`) already accepted
`Frame | None`. Widen the producer side (the abstract base methods and
every subclass override) so the annotation accurately describes the
contract.

Pure annotation change; no runtime behavior difference.
This commit is contained in:
Mark Backman
2026-04-22 11:01:50 -04:00
parent 3f3d3c9203
commit 08fe9157cc
38 changed files with 50 additions and 50 deletions

View File

@@ -418,7 +418,7 @@ class AssemblyAISTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text conversion.
Args:

View File

@@ -449,7 +449,7 @@ class AsyncAITTSService(WebsocketTTSService):
await super().on_audio_context_completed(context_id)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Async API websocket endpoint.
Args:
@@ -620,7 +620,7 @@ class AsyncAIHttpTTSService(TTSService):
self._output_sample_rate = self.sample_rate
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Async's HTTP streaming API.
Args:

View File

@@ -196,7 +196,7 @@ class AWSTranscribeSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data and send to AWS Transcribe.
Args:

View File

@@ -191,7 +191,7 @@ class AzureSTTService(STTService):
return changed
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text conversion.
Feeds audio data to the Azure speech recognizer for processing.

View File

@@ -277,7 +277,7 @@ class CartesiaSTTService(WebsocketSTTService):
if self._websocket and self._websocket.state is State.OPEN:
await self._websocket.send("finalize")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text transcription.
Args:

View File

@@ -660,7 +660,7 @@ class CartesiaTTSService(WebsocketTTSService):
await self._connect_websocket()
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Cartesia's streaming API.
Args:
@@ -873,7 +873,7 @@ class CartesiaHttpTTSService(TTSService):
await self._close_session()
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Cartesia's HTTP API.
Args:

View File

@@ -222,7 +222,7 @@ class DeepgramFluxSageMakerSTTService(DeepgramFluxSTTBase):
# Audio sending and response receiving
# ------------------------------------------------------------------
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Deepgram Flux for transcription.
Args:

View File

@@ -354,7 +354,7 @@ class DeepgramFluxSTTService(DeepgramFluxSTTBase, WebsocketService):
# Audio sending and receiving
# ------------------------------------------------------------------
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Deepgram Flux for transcription.
Transmits raw audio bytes to the Deepgram Flux API for real-time speech

View File

@@ -256,7 +256,7 @@ class DeepgramSageMakerSTTService(STTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Deepgram for transcription.
Args:

View File

@@ -325,7 +325,7 @@ class DeepgramSageMakerTTSService(TTSService):
logger.error(f"{self} error sending Flush message: {e}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Deepgram TTS on SageMaker.
Args:

View File

@@ -514,7 +514,7 @@ class DeepgramSTTService(STTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Deepgram for transcription.
Args:

View File

@@ -330,7 +330,7 @@ class DeepgramTTSService(WebsocketTTSService):
logger.error(f"{self} error sending Flush message: {e}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Deepgram's WebSocket TTS API.
Args:
@@ -441,7 +441,7 @@ class DeepgramHttpTTSService(TTSService):
return True
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Deepgram's TTS API.
Args:

View File

@@ -370,7 +370,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
"""Handle a transcription result with tracing."""
await self.stop_processing_metrics()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Transcribe an audio segment using ElevenLabs' STT API.
Args:
@@ -674,7 +674,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
except Exception as e:
logger.warning(f"Failed to send commit: {e}")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text transcription.
Args:

View File

@@ -889,7 +889,7 @@ class ElevenLabsTTSService(WebsocketTTSService):
await self._websocket.send(json.dumps(msg))
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using ElevenLabs' streaming WebSocket API.
Args:
@@ -1240,7 +1240,7 @@ class ElevenLabsHttpTTSService(TTSService):
return word_times
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using ElevenLabs streaming API with timestamps.
Makes a request to the ElevenLabs API to generate audio and timing data.

View File

@@ -373,7 +373,7 @@ class FishAudioTTSService(InterruptibleTTSService):
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Fish Audio's streaming API.
Args:

View File

@@ -461,7 +461,7 @@ class GladiaSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Run speech-to-text on audio data.
Args:

View File

@@ -931,7 +931,7 @@ class GoogleSTTService(STTService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process an audio chunk for STT transcription.
Args:

View File

@@ -333,7 +333,7 @@ class GradiumSTTService(WebsocketSTTService):
except Exception as e:
logger.warning(f"Failed to send flush: {e}")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text conversion.
Args:

View File

@@ -356,7 +356,7 @@ class GradiumTTSService(WebsocketTTSService):
await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Gradium's streaming API.
Args:

View File

@@ -283,7 +283,7 @@ class InworldHttpTTSService(TTSService):
return (word_times, chunk_end_time)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate TTS audio for the given text.
Args:
@@ -1128,7 +1128,7 @@ class InworldTTSService(WebsocketTTSService):
await self.send_with_retry(json.dumps(msg), self._report_error)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate TTS audio for the given text using the Inworld WebSocket TTS service.
Args:

View File

@@ -336,7 +336,7 @@ class LmntTTSService(InterruptibleTTSService):
logger.error(f"Invalid JSON message: {message}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate TTS audio from text using LMNT's streaming API.
Args:

View File

@@ -185,7 +185,7 @@ class MistralSTTService(STTService):
if self._connection and not self._connection.is_closed:
await self._connection.flush_audio()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Mistral for transcription.
Args:

View File

@@ -366,7 +366,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._websocket.send(json.dumps(msg))
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Neuphonic's streaming API.
Args:
@@ -565,7 +565,7 @@ class NeuphonicHttpTTSService(TTSService):
return None
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Neuphonic streaming API.
Args:

View File

@@ -395,7 +395,7 @@ class NvidiaSTTService(STTService):
)
)
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Process audio data for speech-to-text transcription.
Args:
@@ -661,7 +661,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
"""Handle a transcription result with tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Transcribe an audio segment.
Args:

View File

@@ -526,7 +526,7 @@ class NvidiaTTSService(TTSService):
)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using NVIDIA Nemotron Speech TTS.
On the first call for a turn, starts a persistent ``synthesize_online``

View File

@@ -415,7 +415,7 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to the transcription session.
Audio is streamed over the WebSocket. Transcription results arrive

View File

@@ -431,7 +431,7 @@ class ResembleAITTSService(WebsocketTTSService):
await self._connect_websocket()
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Resemble AI's streaming API.
Args:

View File

@@ -603,7 +603,7 @@ class RimeTTSService(WebsocketTTSService):
self.reset_active_audio_context()
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Rime's streaming API.
Args:
@@ -786,7 +786,7 @@ class RimeHttpTTSService(TTSService):
return language_to_rime_language(language)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Rime's HTTP API.
Args:
@@ -1142,7 +1142,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
await self.push_error(error_msg=f"Error: {e}", exception=e)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Rime's streaming API.
Args:

View File

@@ -570,7 +570,7 @@ class SarvamSTTService(STTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Sarvam for transcription.
Args:

View File

@@ -569,7 +569,7 @@ class SarvamHttpTTSService(TTSService):
await super().start(frame)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Sarvam AI's API.
Args:
@@ -1192,7 +1192,7 @@ class SarvamTTSService(InterruptibleTTSService):
logger.warning("WebSocket not ready, cannot send text")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech audio frames from input text using Sarvam TTS.
Sends text over WebSocket for synthesis and yields corresponding audio or status frames.

View File

@@ -247,7 +247,7 @@ class SmallestSTTService(WebsocketSTTService):
except Exception as e:
logger.warning(f"{self} failed to send finalize: {e}")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio to the Smallest Pulse WebSocket for transcription.
Args:

View File

@@ -390,7 +390,7 @@ class SmallestTTSService(InterruptibleTTSService):
logger.warning(f"{self} unknown message status: {msg}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using Smallest's WebSocket streaming API.
Args:

View File

@@ -402,7 +402,7 @@ class SonioxSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Send audio data to Soniox STT Service.
Args:

View File

@@ -1059,7 +1059,7 @@ class SpeechmaticsSTTService(STTService):
"""Record transcription event for tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Adds audio to the audio buffer and yields None."""
try:
if self._client:

View File

@@ -274,7 +274,7 @@ class STTService(AIService):
return Language(language)
@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Run speech-to-text on the provided audio data.
This method must be implemented by subclasses to provide actual speech

View File

@@ -445,7 +445,7 @@ class TTSService(AIService):
# Converts the text to audio.
@abstractmethod
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Run text-to-speech synthesis on the provided text.
This method must be implemented by subclasses to provide actual TTS functionality.

View File

@@ -209,7 +209,7 @@ class XAISTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
"""Forward raw audio bytes to the xAI STT WebSocket.
Transcription frames are pushed from the receive task, not yielded

View File

@@ -188,7 +188,7 @@ class XAIHttpTTSService(TTSService):
self._session = None
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate speech from text using xAI's TTS API."""
logger.debug(f"{self}: Generating TTS [{text}]")
@@ -466,7 +466,7 @@ class XAITTSService(InterruptibleTTSService):
logger.debug(f"{self}: unhandled xAI message type: {msg_type}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
"""Generate TTS audio from text using xAI's streaming WebSocket API."""
logger.debug(f"{self}: Generating TTS [{text}]")