From d2037894902fc0fc730b2a56f846a13fe0c14013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 29 May 2025 09:10:45 -0700 Subject: [PATCH] add STT result field to TranscriptionFrame/InterimTranscriptionFrame --- CHANGELOG.md | 5 +++++ src/pipecat/frames/frames.py | 11 +++++++---- src/pipecat/services/assemblyai/stt.py | 16 ++++++++++++++-- src/pipecat/services/aws/stt.py | 2 ++ src/pipecat/services/azure/stt.py | 8 +++++++- src/pipecat/services/deepgram/stt.py | 16 ++++++++++++++-- src/pipecat/services/fal/stt.py | 6 +++++- .../services/gemini_multimodal_live/gemini.py | 5 ++++- src/pipecat/services/gladia/stt.py | 14 ++++++++++++-- src/pipecat/services/google/stt.py | 14 ++++++++++++-- .../services/openai_realtime_beta/openai.py | 4 ++-- src/pipecat/services/riva/stt.py | 14 ++++++++++++-- src/pipecat/transports/services/daily.py | 10 ++++++++-- 13 files changed, 104 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e92701afa..03bc4dae9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- You can now access STT service results through the new + `TranscriptionFrame.result` and `InterimTranscriptionFrame.result` field. This + is useful in case you use some specific settings for the STT and you want to + access the STT results. + - The examples runner is now public from the `pipecat.examples` package. This allows everyone to build their own examples and run them easily. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 758ba7556..b24ff7b19 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -228,14 +228,15 @@ class TTSTextFrame(TextFrame): @dataclass class TranscriptionFrame(TextFrame): - """A text frame with transcription-specific data. Will be placed in the - transport's receive queue when a participant speaks. + """A text frame with transcription-specific data. The `result` field + contains the result from the STT service if available. """ user_id: str timestamp: str language: Optional[Language] = None + result: Optional[Any] = None def __str__(self): return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" @@ -243,14 +244,16 @@ class TranscriptionFrame(TextFrame): @dataclass class InterimTranscriptionFrame(TextFrame): - """A text frame with interim transcription-specific data. Will be placed in - the transport's receive queue when a participant speaks. + """A text frame with interim transcription-specific data. The `result` field + contains the result from the STT service if available. + """ text: str user_id: str timestamp: str language: Optional[Language] = None + result: Optional[Any] = None def __str__(self): return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index 41a27fd5b..50e16756e 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -123,9 +123,21 @@ class AssemblyAISTTService(STTService): language = self._settings["language"] if is_final: - frame = TranscriptionFrame(transcript.text, "", timestamp, language) + frame = TranscriptionFrame( + transcript.text, + "", + timestamp, + language, + result=transcript, + ) else: - frame = InterimTranscriptionFrame(transcript.text, "", timestamp, language) + frame = InterimTranscriptionFrame( + transcript.text, + "", + timestamp, + language, + result=transcript, + ) asyncio.run_coroutine_threadsafe( self._handle_transcription(transcript.text, is_final, language), diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index 5016d5e78..4490f3795 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -305,6 +305,7 @@ class AWSTranscribeSTTService(STTService): "", time_now_iso8601(), self._settings["language"], + result=result, ) ) await self._handle_transcription( @@ -320,6 +321,7 @@ class AWSTranscribeSTTService(STTService): "", time_now_iso8601(), self._settings["language"], + result=result, ) ) elif headers.get(":message-type") == "exception": diff --git a/src/pipecat/services/azure/stt.py b/src/pipecat/services/azure/stt.py index 9c9b386eb..abd8acbd3 100644 --- a/src/pipecat/services/azure/stt.py +++ b/src/pipecat/services/azure/stt.py @@ -121,7 +121,13 @@ class AzureSTTService(STTService): def _on_handle_recognized(self, event): if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: language = getattr(event.result, "language", None) or self._settings.get("language") - frame = TranscriptionFrame(event.result.text, "", time_now_iso8601(), language) + frame = TranscriptionFrame( + event.result.text, + "", + time_now_iso8601(), + language, + result=event, + ) asyncio.run_coroutine_threadsafe( self._handle_transcription(event.result.text, True, language), self.get_event_loop() ) diff --git a/src/pipecat/services/deepgram/stt.py b/src/pipecat/services/deepgram/stt.py index 7b68089c5..308ad1d1d 100644 --- a/src/pipecat/services/deepgram/stt.py +++ b/src/pipecat/services/deepgram/stt.py @@ -212,14 +212,26 @@ class DeepgramSTTService(STTService): await self.stop_ttfb_metrics() if is_final: await self.push_frame( - TranscriptionFrame(transcript, "", time_now_iso8601(), language) + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + language, + result=result, + ) ) await self._handle_transcription(transcript, is_final, language) await self.stop_processing_metrics() else: # For interim transcriptions, just push the frame without tracing await self.push_frame( - InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language) + InterimTranscriptionFrame( + transcript, + "", + time_now_iso8601(), + language, + result=result, + ) ) async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/services/fal/stt.py b/src/pipecat/services/fal/stt.py index a019694fd..1e26d9958 100644 --- a/src/pipecat/services/fal/stt.py +++ b/src/pipecat/services/fal/stt.py @@ -252,7 +252,11 @@ class FalSTTService(SegmentedSTTService): await self._handle_transcription(text, True, self._settings["language"]) logger.debug(f"Transcription: [{text}]") yield TranscriptionFrame( - text, "", time_now_iso8601(), Language(self._settings["language"]) + text, + "", + time_now_iso8601(), + Language(self._settings["language"]), + result=response, ) except Exception as e: diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 8bc6fea53..25377e183 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -937,7 +937,10 @@ class GeminiMultimodalLiveLLMService(LLMService): logger.debug(f"[Transcription:user] [{complete_sentence}]") await self.push_frame( TranscriptionFrame( - text=complete_sentence, user_id="", timestamp=time_now_iso8601() + text=complete_sentence, + user_id="", + timestamp=time_now_iso8601(), + result=evt, ), FrameDirection.UPSTREAM, ) diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py index 33c927158..6ac5edad9 100644 --- a/src/pipecat/services/gladia/stt.py +++ b/src/pipecat/services/gladia/stt.py @@ -408,7 +408,13 @@ class GladiaSTTService(STTService): if confidence >= self._confidence: if is_final: await self.push_frame( - TranscriptionFrame(transcript, "", time_now_iso8601(), language) + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + language, + result=content, + ) ) await self._handle_transcription( transcript=transcript, @@ -418,7 +424,11 @@ class GladiaSTTService(STTService): else: await self.push_frame( InterimTranscriptionFrame( - transcript, "", time_now_iso8601(), language + transcript, + "", + time_now_iso8601(), + language, + result=content, ) ) elif content["type"] == "translation": diff --git a/src/pipecat/services/google/stt.py b/src/pipecat/services/google/stt.py index 49b8fba6f..4fd129af3 100644 --- a/src/pipecat/services/google/stt.py +++ b/src/pipecat/services/google/stt.py @@ -816,7 +816,13 @@ class GoogleSTTService(STTService): if result.is_final: self._last_transcript_was_final = True await self.push_frame( - TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language) + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + primary_language, + result=result, + ) ) await self.stop_processing_metrics() await self._handle_transcription( @@ -829,7 +835,11 @@ class GoogleSTTService(STTService): await self.stop_ttfb_metrics() await self.push_frame( InterimTranscriptionFrame( - transcript, "", time_now_iso8601(), primary_language + transcript, + "", + time_now_iso8601(), + primary_language, + result=result, ) ) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index e705a9b18..579be2ebe 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -464,7 +464,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): if self._send_transcription_frames: await self.push_frame( # no way to get a language code? - InterimTranscriptionFrame(evt.delta, "", time_now_iso8601()) + InterimTranscriptionFrame(evt.delta, "", time_now_iso8601(), result=evt) ) async def handle_evt_input_audio_transcription_completed(self, evt): @@ -473,7 +473,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): if self._send_transcription_frames: await self.push_frame( # no way to get a language code? - TranscriptionFrame(evt.transcript, "", time_now_iso8601()) + TranscriptionFrame(evt.transcript, "", time_now_iso8601(), result=evt) ) pair = self._user_and_response_message_tuple if pair: diff --git a/src/pipecat/services/riva/stt.py b/src/pipecat/services/riva/stt.py index 92a3463c3..1e9d8cd6c 100644 --- a/src/pipecat/services/riva/stt.py +++ b/src/pipecat/services/riva/stt.py @@ -256,7 +256,13 @@ class RivaSTTService(STTService): if result.is_final: await self.stop_processing_metrics() await self.push_frame( - TranscriptionFrame(transcript, "", time_now_iso8601(), self._language_code) + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._language_code, + result=result, + ) ) await self._handle_transcription( transcript=transcript, @@ -266,7 +272,11 @@ class RivaSTTService(STTService): else: await self.push_frame( InterimTranscriptionFrame( - transcript, "", time_now_iso8601(), self._language_code + transcript, + "", + time_now_iso8601(), + self._language_code, + result=result, ) ) diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 075d57dfd..065981b52 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -1601,10 +1601,16 @@ class DailyTransport(BaseTransport): except KeyError: language = None if is_final: - frame = TranscriptionFrame(text, participant_id, timestamp, language) + frame = TranscriptionFrame(text, participant_id, timestamp, language, result=message) logger.debug(f"Transcription (from: {participant_id}): [{text}]") else: - frame = InterimTranscriptionFrame(text, participant_id, timestamp, language) + frame = InterimTranscriptionFrame( + text, + participant_id, + timestamp, + language, + result=message, + ) if self._input: await self._input.push_transcription_frame(frame)