processors(realtime-ai): add transcription messages

This commit is contained in:
Aleix Conchillo Flaqué
2024-07-19 17:33:49 -07:00
parent 13827e1282
commit f094c42728

View File

@@ -11,12 +11,14 @@ from pydantic import BaseModel, ValidationError
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMModelUpdateFrame,
StartFrame,
TTSSpeakFrame,
TTSVoiceUpdateFrame,
TranscriptionFrame,
TransportMessageFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
@@ -93,6 +95,24 @@ class RealtimeAILLMContextResponse(BaseModel):
messages: List[dict]
class RealtimeAITranscriptionMessageData(BaseModel):
    """Payload shared by the final and interim user-transcription messages."""

    # Transcribed text.
    text: str
    # Identifier of the user the transcription belongs to.
    user_id: str
    # Time of the transcription, carried as a string.
    # NOTE(review): exact format (e.g. ISO 8601) is not visible here — confirm.
    timestamp: str
class RealtimeAITranscriptionMessage(BaseModel):
    """Envelope for a "user-transcription" message.

    The `tag` and `type` fields are fixed literals; `data` carries the
    transcription payload.
    """

    # Constant tag marking the message as a realtime-ai message.
    tag: Literal["realtime-ai"] = "realtime-ai"
    # Constant discriminator for this message kind.
    type: Literal["user-transcription"] = "user-transcription"
    # Transcription text, user id and timestamp.
    data: RealtimeAITranscriptionMessageData
class RealtimeAIInterimTranscriptionMessage(BaseModel):
    """Envelope for a "user-interim-transcription" message.

    Same shape as the "user-transcription" message; only the `type` literal
    differs.
    """

    # Constant tag marking the message as a realtime-ai message.
    tag: Literal["realtime-ai"] = "realtime-ai"
    # Constant discriminator for this message kind.
    type: Literal["user-interim-transcription"] = "user-interim-transcription"
    # Transcription text, user id and timestamp.
    data: RealtimeAITranscriptionMessageData
class RealtimeAIUserStartedSpeakingMessage(BaseModel):
tag: Literal["realtime-ai"] = "realtime-ai"
type: Literal["user-started-speaking"] = "user-started-speaking"
@@ -137,9 +157,31 @@ class RealtimeAIProcessor(FrameProcessor):
if isinstance(frame, StartFrame):
self._start_frame = frame
await self._handle_setup(self._setup)
elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
await self._handle_transcriptions(frame)
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_interruptions(frame)
# TODO(aleix): Once we add support for using custom pipelines, the STTs will
# be in the pipeline after this processor. This means the STT will have to
# push transcriptions upstream as well.
async def _handle_transcriptions(self, frame: Frame):
    """Wrap a transcription frame in a realtime-ai message and push it.

    A `TranscriptionFrame` becomes a `RealtimeAITranscriptionMessage` and an
    `InterimTranscriptionFrame` becomes a
    `RealtimeAIInterimTranscriptionMessage`. Any other frame type is ignored.
    The message is serialized with `model_dump(exclude_none=True)` and sent
    through `push_frame` inside a `TransportMessageFrame`.
    """
    # Pick the envelope class matching the frame type; bail out otherwise.
    if isinstance(frame, TranscriptionFrame):
        envelope_cls = RealtimeAITranscriptionMessage
    elif isinstance(frame, InterimTranscriptionFrame):
        envelope_cls = RealtimeAIInterimTranscriptionMessage
    else:
        return

    # Both message kinds share the same payload shape.
    payload = RealtimeAITranscriptionMessageData(
        text=frame.text,
        user_id=frame.user_id,
        timestamp=frame.timestamp,
    )
    envelope = envelope_cls(data=payload)

    serialized = envelope.model_dump(exclude_none=True)
    await self.push_frame(TransportMessageFrame(message=serialized))
async def _handle_interruptions(self, frame: Frame):
message = None
if isinstance(frame, UserStartedSpeakingFrame):