add STT result field to TranscriptionFrame/InterimTranscriptionFrame

This commit is contained in:
Aleix Conchillo Flaqué
2025-05-29 09:10:45 -07:00
parent 7ea0e31cd4
commit d203789490
13 changed files with 104 additions and 21 deletions

View File

@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- You can now access STT service results through the new
`TranscriptionFrame.result` and `InterimTranscriptionFrame.result` field. This
is useful in case you use some specific settings for the STT and you want to
access the STT results.
- The examples runner is now public from the `pipecat.examples` package. This
allows everyone to build their own examples and run them easily.

View File

@@ -228,14 +228,15 @@ class TTSTextFrame(TextFrame):
@dataclass
class TranscriptionFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the
transport's receive queue when a participant speaks.
"""A text frame with transcription-specific data. The `result` field
contains the result from the STT service if available.
"""
user_id: str
timestamp: str
language: Optional[Language] = None
result: Optional[Any] = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
@@ -243,14 +244,16 @@ class TranscriptionFrame(TextFrame):
@dataclass
class InterimTranscriptionFrame(TextFrame):
"""A text frame with interim transcription-specific data. Will be placed in
the transport's receive queue when a participant speaks.
"""A text frame with interim transcription-specific data. The `result` field
contains the result from the STT service if available.
"""
text: str
user_id: str
timestamp: str
language: Optional[Language] = None
result: Optional[Any] = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"

View File

@@ -123,9 +123,21 @@ class AssemblyAISTTService(STTService):
language = self._settings["language"]
if is_final:
frame = TranscriptionFrame(transcript.text, "", timestamp, language)
frame = TranscriptionFrame(
transcript.text,
"",
timestamp,
language,
result=transcript,
)
else:
frame = InterimTranscriptionFrame(transcript.text, "", timestamp, language)
frame = InterimTranscriptionFrame(
transcript.text,
"",
timestamp,
language,
result=transcript,
)
asyncio.run_coroutine_threadsafe(
self._handle_transcription(transcript.text, is_final, language),

View File

@@ -305,6 +305,7 @@ class AWSTranscribeSTTService(STTService):
"",
time_now_iso8601(),
self._settings["language"],
result=result,
)
)
await self._handle_transcription(
@@ -320,6 +321,7 @@ class AWSTranscribeSTTService(STTService):
"",
time_now_iso8601(),
self._settings["language"],
result=result,
)
)
elif headers.get(":message-type") == "exception":

View File

@@ -121,7 +121,13 @@ class AzureSTTService(STTService):
def _on_handle_recognized(self, event):
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
language = getattr(event.result, "language", None) or self._settings.get("language")
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601(), language)
frame = TranscriptionFrame(
event.result.text,
"",
time_now_iso8601(),
language,
result=event,
)
asyncio.run_coroutine_threadsafe(
self._handle_transcription(event.result.text, True, language), self.get_event_loop()
)

View File

@@ -212,14 +212,26 @@ class DeepgramSTTService(STTService):
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
language,
result=result,
)
)
await self._handle_transcription(transcript, is_final, language)
await self.stop_processing_metrics()
else:
# For interim transcriptions, just push the frame without tracing
await self.push_frame(
InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language)
InterimTranscriptionFrame(
transcript,
"",
time_now_iso8601(),
language,
result=result,
)
)
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -252,7 +252,11 @@ class FalSTTService(SegmentedSTTService):
await self._handle_transcription(text, True, self._settings["language"])
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text, "", time_now_iso8601(), Language(self._settings["language"])
text,
"",
time_now_iso8601(),
Language(self._settings["language"]),
result=response,
)
except Exception as e:

View File

@@ -937,7 +937,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
logger.debug(f"[Transcription:user] [{complete_sentence}]")
await self.push_frame(
TranscriptionFrame(
text=complete_sentence, user_id="", timestamp=time_now_iso8601()
text=complete_sentence,
user_id="",
timestamp=time_now_iso8601(),
result=evt,
),
FrameDirection.UPSTREAM,
)

View File

@@ -408,7 +408,13 @@ class GladiaSTTService(STTService):
if confidence >= self._confidence:
if is_final:
await self.push_frame(
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
language,
result=content,
)
)
await self._handle_transcription(
transcript=transcript,
@@ -418,7 +424,11 @@ class GladiaSTTService(STTService):
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript, "", time_now_iso8601(), language
transcript,
"",
time_now_iso8601(),
language,
result=content,
)
)
elif content["type"] == "translation":

View File

@@ -816,7 +816,13 @@ class GoogleSTTService(STTService):
if result.is_final:
self._last_transcript_was_final = True
await self.push_frame(
TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
primary_language,
result=result,
)
)
await self.stop_processing_metrics()
await self._handle_transcription(
@@ -829,7 +835,11 @@ class GoogleSTTService(STTService):
await self.stop_ttfb_metrics()
await self.push_frame(
InterimTranscriptionFrame(
transcript, "", time_now_iso8601(), primary_language
transcript,
"",
time_now_iso8601(),
primary_language,
result=result,
)
)

View File

@@ -464,7 +464,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
if self._send_transcription_frames:
await self.push_frame(
# no way to get a language code?
InterimTranscriptionFrame(evt.delta, "", time_now_iso8601())
InterimTranscriptionFrame(evt.delta, "", time_now_iso8601(), result=evt)
)
async def handle_evt_input_audio_transcription_completed(self, evt):
@@ -473,7 +473,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
if self._send_transcription_frames:
await self.push_frame(
# no way to get a language code?
TranscriptionFrame(evt.transcript, "", time_now_iso8601())
TranscriptionFrame(evt.transcript, "", time_now_iso8601(), result=evt)
)
pair = self._user_and_response_message_tuple
if pair:

View File

@@ -256,7 +256,13 @@ class RivaSTTService(STTService):
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
TranscriptionFrame(transcript, "", time_now_iso8601(), self._language_code)
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._language_code,
result=result,
)
)
await self._handle_transcription(
transcript=transcript,
@@ -266,7 +272,11 @@ class RivaSTTService(STTService):
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript, "", time_now_iso8601(), self._language_code
transcript,
"",
time_now_iso8601(),
self._language_code,
result=result,
)
)

View File

@@ -1601,10 +1601,16 @@ class DailyTransport(BaseTransport):
except KeyError:
language = None
if is_final:
frame = TranscriptionFrame(text, participant_id, timestamp, language)
frame = TranscriptionFrame(text, participant_id, timestamp, language, result=message)
logger.debug(f"Transcription (from: {participant_id}): [{text}]")
else:
frame = InterimTranscriptionFrame(text, participant_id, timestamp, language)
frame = InterimTranscriptionFrame(
text,
participant_id,
timestamp,
language,
result=message,
)
if self._input:
await self._input.push_transcription_frame(frame)