From 90b7f65545b7f61d0c40efd7ea19b315b8940201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 15 Oct 2024 09:24:06 -0700 Subject: [PATCH] rtvi: add RTVIBotTranscriptionProcessor to send `bot-transcription` --- CHANGELOG.md | 4 +++ src/pipecat/processors/frameworks/rtvi.py | 30 ++++++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00566d481..4a98700df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `RTVIBotTranscriptionProcessor` which will send the RTVI + `bot-transcription` protocol message. These are TTS text aggregated (into + sentences) messages. + - Added new input params to the `MarkdownTextFilter` utility. You can set `filter_code` to filter code from text and `filter_tables` to filter tables from text. diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index f1e7ea023..1616b6790 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -42,6 +42,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContextFrame, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.string import match_endofsentence RTVI_PROTOCOL_VERSION = "0.2" @@ -275,6 +276,12 @@ class RTVITextMessageData(BaseModel): text: str +class RTVIBotTranscriptionMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-transcription"] = "bot-transcription" + data: RTVITextMessageData + + class RTVIBotLLMTextMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["bot-llm-text"] = "bot-llm-text" @@ -437,16 +444,33 @@ class RTVIUserLLMTextProcessor(RTVIFrameProcessor): if message["role"] == "user": content = message["content"] if isinstance(content, list): - print("LIST") text = " ".join(item["text"] for item in content if "text" in item) else: - print("STRING") text = content - rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text)) await self._push_transport_message_urgent(rtvi_message) +class RTVIBotTranscriptionProcessor(RTVIFrameProcessor): + def __init__(self): + super().__init__() + self._aggregation = "" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TextFrame): + self._aggregation += frame.text + if match_endofsentence(self._aggregation): + message = RTVIBotTranscriptionMessage( + data=RTVITextMessageData(text=self._aggregation) + ) + await self._push_transport_message_urgent(message) + self._aggregation = "" + + class RTVIBotLLMProcessor(RTVIFrameProcessor): def __init__(self, **kwargs): super().__init__(**kwargs)