add STT result field to TranscriptionFrame/InterimTranscriptionFrame
This commit is contained in:
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- You can now access STT service results through the new
|
||||
`TranscriptionFrame.result` and `InterimTranscriptionFrame.result` fields. This
|
||||
is useful when you use service-specific STT settings and want to
|
||||
access the STT results directly.
|
||||
|
||||
- The examples runner is now public from the `pipecat.examples` package. This
|
||||
allows everyone to build their own examples and run them easily.
|
||||
|
||||
|
||||
@@ -228,14 +228,15 @@ class TTSTextFrame(TextFrame):
|
||||
|
||||
@dataclass
|
||||
class TranscriptionFrame(TextFrame):
|
||||
"""A text frame with transcription-specific data. Will be placed in the
|
||||
transport's receive queue when a participant speaks.
|
||||
"""A text frame with transcription-specific data. The `result` field
|
||||
contains the result from the STT service if available.
|
||||
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
timestamp: str
|
||||
language: Optional[Language] = None
|
||||
result: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
|
||||
@@ -243,14 +244,16 @@ class TranscriptionFrame(TextFrame):
|
||||
|
||||
@dataclass
|
||||
class InterimTranscriptionFrame(TextFrame):
|
||||
"""A text frame with interim transcription-specific data. Will be placed in
|
||||
the transport's receive queue when a participant speaks.
|
||||
"""A text frame with interim transcription-specific data. The `result` field
|
||||
contains the result from the STT service if available.
|
||||
|
||||
"""
|
||||
|
||||
text: str
|
||||
user_id: str
|
||||
timestamp: str
|
||||
language: Optional[Language] = None
|
||||
result: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
|
||||
|
||||
@@ -123,9 +123,21 @@ class AssemblyAISTTService(STTService):
|
||||
language = self._settings["language"]
|
||||
|
||||
if is_final:
|
||||
frame = TranscriptionFrame(transcript.text, "", timestamp, language)
|
||||
frame = TranscriptionFrame(
|
||||
transcript.text,
|
||||
"",
|
||||
timestamp,
|
||||
language,
|
||||
result=transcript,
|
||||
)
|
||||
else:
|
||||
frame = InterimTranscriptionFrame(transcript.text, "", timestamp, language)
|
||||
frame = InterimTranscriptionFrame(
|
||||
transcript.text,
|
||||
"",
|
||||
timestamp,
|
||||
language,
|
||||
result=transcript,
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._handle_transcription(transcript.text, is_final, language),
|
||||
|
||||
@@ -305,6 +305,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
@@ -320,6 +321,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
|
||||
@@ -121,7 +121,13 @@ class AzureSTTService(STTService):
|
||||
def _on_handle_recognized(self, event):
|
||||
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
|
||||
language = getattr(event.result, "language", None) or self._settings.get("language")
|
||||
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601(), language)
|
||||
frame = TranscriptionFrame(
|
||||
event.result.text,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=event,
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._handle_transcription(event.result.text, True, language), self.get_event_loop()
|
||||
)
|
||||
|
||||
@@ -212,14 +212,26 @@ class DeepgramSTTService(STTService):
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# For interim transcriptions, just push the frame without tracing
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language)
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -252,7 +252,11 @@ class FalSTTService(SegmentedSTTService):
|
||||
await self._handle_transcription(text, True, self._settings["language"])
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text, "", time_now_iso8601(), Language(self._settings["language"])
|
||||
text,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
Language(self._settings["language"]),
|
||||
result=response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -937,7 +937,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
logger.debug(f"[Transcription:user] [{complete_sentence}]")
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text=complete_sentence, user_id="", timestamp=time_now_iso8601()
|
||||
text=complete_sentence,
|
||||
user_id="",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=evt,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
@@ -408,7 +408,13 @@ class GladiaSTTService(STTService):
|
||||
if confidence >= self._confidence:
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
@@ -418,7 +424,11 @@ class GladiaSTTService(STTService):
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), language
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
elif content["type"] == "translation":
|
||||
|
||||
@@ -816,7 +816,13 @@ class GoogleSTTService(STTService):
|
||||
if result.is_final:
|
||||
self._last_transcript_was_final = True
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
primary_language,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self.stop_processing_metrics()
|
||||
await self._handle_transcription(
|
||||
@@ -829,7 +835,11 @@ class GoogleSTTService(STTService):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), primary_language
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
primary_language,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -464,7 +464,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
InterimTranscriptionFrame(evt.delta, "", time_now_iso8601())
|
||||
InterimTranscriptionFrame(evt.delta, "", time_now_iso8601(), result=evt)
|
||||
)
|
||||
|
||||
async def handle_evt_input_audio_transcription_completed(self, evt):
|
||||
@@ -473,7 +473,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
TranscriptionFrame(evt.transcript, "", time_now_iso8601())
|
||||
TranscriptionFrame(evt.transcript, "", time_now_iso8601(), result=evt)
|
||||
)
|
||||
pair = self._user_and_response_message_tuple
|
||||
if pair:
|
||||
|
||||
@@ -256,7 +256,13 @@ class RivaSTTService(STTService):
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), self._language_code)
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
@@ -266,7 +272,11 @@ class RivaSTTService(STTService):
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), self._language_code
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1601,10 +1601,16 @@ class DailyTransport(BaseTransport):
|
||||
except KeyError:
|
||||
language = None
|
||||
if is_final:
|
||||
frame = TranscriptionFrame(text, participant_id, timestamp, language)
|
||||
frame = TranscriptionFrame(text, participant_id, timestamp, language, result=message)
|
||||
logger.debug(f"Transcription (from: {participant_id}): [{text}]")
|
||||
else:
|
||||
frame = InterimTranscriptionFrame(text, participant_id, timestamp, language)
|
||||
frame = InterimTranscriptionFrame(
|
||||
text,
|
||||
participant_id,
|
||||
timestamp,
|
||||
language,
|
||||
result=message,
|
||||
)
|
||||
|
||||
if self._input:
|
||||
await self._input.push_transcription_frame(frame)
|
||||
|
||||
Reference in New Issue
Block a user