From 55879bf36561052eccbb0d9447c2dabfc01ff5f0 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 13 Dec 2024 15:38:59 -0500 Subject: [PATCH] Add TranscriptionProcessor --- .../07a-interruptible-anthropic.py | 42 ++++- .../28a-transcription-update-openai.py | 128 +++++++++++++++ .../28b-transcription-update-anthropic.py | 128 +++++++++++++++ .../28c-transcription-update-gemini.py | 138 ++++++++++++++++ src/pipecat/frames/frames.py | 36 ++++- .../processors/transcript_processor.py | 150 ++++++++++++++++++ 6 files changed, 613 insertions(+), 9 deletions(-) create mode 100644 examples/foundational/28a-transcription-update-openai.py create mode 100644 examples/foundational/28b-transcription-update-anthropic.py create mode 100644 examples/foundational/28c-transcription-update-gemini.py create mode 100644 src/pipecat/processors/transcript_processor.py diff --git a/examples/foundational/07a-interruptible-anthropic.py b/examples/foundational/07a-interruptible-anthropic.py index e7e680eab..25a301269 100644 --- a/examples/foundational/07a-interruptible-anthropic.py +++ b/examples/foundational/07a-interruptible-anthropic.py @@ -7,6 +7,7 @@ import asyncio import os import sys +from typing import List import aiohttp from dotenv import load_dotenv @@ -14,12 +15,13 @@ from loguru import logger from runner import configure from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.frames.frames import LLMMessagesFrame +from pipecat.frames.frames import Frame, LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.anthropic import AnthropicLLMService +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.anthropic import AnthropicLLMContext, AnthropicLLMService from pipecat.services.cartesia import CartesiaTTSService from pipecat.transports.services.daily import DailyParams, DailyTransport @@ -29,6 +31,28 @@ logger.remove(0) logger.add(sys.stderr, level="DEBUG") +class TestAnthropicLLMService(AnthropicLLMService): + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, LLMMessagesFrame): + logger.info("Original OpenAI format messages:") + logger.info(frame.messages) + + # Convert to Anthropic format + context = AnthropicLLMContext.from_messages(frame.messages) + logger.info("Converted to Anthropic format:") + logger.info(context.messages) + + # Convert back to OpenAI format + openai_messages = [] + for msg in context.messages: + converted = context.to_standard_messages(msg) + openai_messages.extend(converted) + logger.info("Converted back to OpenAI format:") + logger.info(openai_messages) + + await super().process_frame(frame, direction) + + async def main(): async with aiohttp.ClientSession() as session: (room_url, token) = await configure(session) @@ -50,18 +74,24 @@ async def main(): voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady ) - llm = AnthropicLLMService( + llm = TestAnthropicLLMService( api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229" ) - # todo: think more about how to handle system prompts in a more general way. OpenAI, - # Google, and Anthropic all have slightly different approaches to providing a system - # prompt. + # Test messages including various formats messages = [ { "role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello.", }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello! How can I help you today?"}, + {"type": "text", "text": "I'm ready to assist."}, + ], + }, + {"role": "user", "content": "Hi there!"}, ] context = OpenAILLMContext(messages) diff --git a/examples/foundational/28a-transcription-update-openai.py b/examples/foundational/28a-transcription-update-openai.py new file mode 100644 index 000000000..ec103ff82 --- /dev/null +++ b/examples/foundational/28a-transcription-update-openai.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from typing import List + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +class TranscriptHandler: + """Simple handler to demonstrate transcript processing.""" + + def __init__(self): + self.messages: List[TranscriptionMessage] = [] + + async def on_transcript_update( + self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame + ): + """Handle new transcript messages.""" + self.messages.extend(frame.messages) + + # Log the new messages + logger.info("New transcript messages:") + for msg in frame.messages: + logger.info(f"{msg.role}: {msg.content}") + + # Log the full transcript + logger.info("Full transcript:") + for msg in self.messages: + logger.info(f"{msg.role}: {msg.content}") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = OpenAILLMService( + api_key=os.getenv("OPENAI_API_KEY"), + model="gpt-4o", + ) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way. Say hello.", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + # Create transcript processor and handler + transcript_processor = TranscriptProcessor() + transcript_handler = TranscriptHandler() + + # Register event handler for transcript updates + @transcript_processor.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + await transcript_handler.on_transcript_update(processor, frame) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + transcript_processor, # Process transcripts + ] + ) + + task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/foundational/28b-transcription-update-anthropic.py b/examples/foundational/28b-transcription-update-anthropic.py new file mode 100644 index 000000000..23ee93a21 --- /dev/null +++ b/examples/foundational/28b-transcription-update-anthropic.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from typing import List + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.services.anthropic import AnthropicLLMService +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +class TranscriptHandler: + """Simple handler to demonstrate transcript processing.""" + + def __init__(self): + self.messages: List[TranscriptionMessage] = [] + + async def on_transcript_update( + self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame + ): + """Handle new transcript messages.""" + self.messages.extend(frame.messages) + + # Log the new messages + logger.info("New transcript messages:") + for msg in frame.messages: + logger.info(f"{msg.role}: {msg.content}") + + # Log the full transcript + logger.info("Full transcript:") + for msg in self.messages: + logger.info(f"{msg.role}: {msg.content}") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = AnthropicLLMService( + api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20241022" + ) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.", + }, + {"role": "user", "content": "Say hello."}, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + # Create transcript processor and handler + transcript_processor = TranscriptProcessor() + transcript_handler = TranscriptHandler() + + # Register event handler for transcript updates + @transcript_processor.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + await transcript_handler.on_transcript_update(processor, frame) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + transcript_processor, # Process transcripts + ] + ) + + task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/foundational/28c-transcription-update-gemini.py b/examples/foundational/28c-transcription-update-gemini.py new file mode 100644 index 000000000..27291a7c9 --- /dev/null +++ b/examples/foundational/28c-transcription-update-gemini.py @@ -0,0 +1,138 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from typing import List + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.google import GoogleLLMService +from pipecat.services.openai import OpenAILLMContext +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +class TranscriptHandler: + """Simple handler to demonstrate transcript processing.""" + + def __init__(self): + self.messages: List[TranscriptionMessage] = [] + + async def on_transcript_update( + self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame + ): + """Handle new transcript messages.""" + self.messages.extend(frame.messages) + + # Log the new messages + logger.info("New transcript messages:") + for msg in frame.messages: + logger.info(f"{msg.role}: {msg.content}") + + # Log the full transcript + logger.info("Full transcript:") + for msg in self.messages: + logger.info(f"{msg.role}: {msg.content}") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = GoogleLLMService( + model="models/gemini-2.0-flash-exp", + # model="gemini-exp-1114", + api_key=os.getenv("GOOGLE_API_KEY"), + ) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative, helpful, and brief way.", + }, + {"role": "user", "content": "Say hello."}, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + # Create transcript processor and handler + transcript_processor = TranscriptProcessor() + transcript_handler = TranscriptHandler() + + # Register event handler for transcript updates + @transcript_processor.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + await transcript_handler.on_transcript_update(processor, frame) + + pipeline = Pipeline( + [ + transport.input(), + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + transcript_processor, + ] + ) + + task = PipelineTask( + pipeline, + PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index d3792f537..f74d30371 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -5,7 +5,7 @@ # from dataclasses import dataclass, field -from typing import Any, List, Mapping, Optional, Tuple +from typing import Any, List, Literal, Mapping, Optional, Tuple, TypeAlias from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.clocks.base_clock import BaseClock @@ -195,7 +195,8 @@ class TranscriptionFrame(TextFrame): @dataclass class InterimTranscriptionFrame(TextFrame): """A text frame with interim transcription-specific data. Will be placed in - the transport's receive queue when a participant speaks.""" + the transport's receive queue when a participant speaks. + """ text: str user_id: str @@ -206,6 +207,34 @@ class InterimTranscriptionFrame(TextFrame): return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" +@dataclass +class TranscriptionMessage: + """A message in a conversation transcript containing the role and content. + + Messages are in standard format with roles normalized to user/assistant. + """ + + role: Literal["user", "assistant"] + content: str + timestamp: str | None = None + + +@dataclass +class TranscriptionUpdateFrame(DataFrame): + """A frame containing new messages added to the conversation transcript. + + This frame is emitted when new messages are added to the conversation history, + containing only the newly added messages rather than the full transcript. + Messages have normalized roles (user/assistant) regardless of the LLM service used. + """ + + messages: List[TranscriptionMessage] + + def __str__(self): + pts = format_pts(self.pts) + return f"{self.name}(pts: {pts}, messages: {len(self.messages)})" + + @dataclass class LLMMessagesFrame(DataFrame): """A frame containing a list of LLM messages. Used to signal that an LLM @@ -546,7 +575,8 @@ class EndFrame(ControlFrame): @dataclass class LLMFullResponseStartFrame(ControlFrame): """Used to indicate the beginning of an LLM response. Following by one or - more TextFrame and a final LLMFullResponseEndFrame.""" + more TextFrame and a final LLMFullResponseEndFrame. + """ pass diff --git a/src/pipecat/processors/transcript_processor.py b/src/pipecat/processors/transcript_processor.py new file mode 100644 index 000000000..97173f967 --- /dev/null +++ b/src/pipecat/processors/transcript_processor.py @@ -0,0 +1,150 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from loguru import logger + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + TranscriptionMessage, + TranscriptionUpdateFrame, +) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class TranscriptProcessor(FrameProcessor): + """Processes LLM context frames to generate conversation transcripts. + + This processor monitors OpenAILLMContextFrame frames and extracts conversation + content, filtering out system messages and function calls. When new messages + are detected, it emits a TranscriptionUpdateFrame containing only the new + messages. + + Each LLM context (OpenAI, Anthropic, Google) provides conversion to the standard format: + [ + { + "role": "user", + "content": [{"type": "text", "text": "Hi, how are you?"}] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "Great! And you?"}] + } + ] + + Events: + on_transcript_update: Emitted when new transcript messages are available. + Args: TranscriptionUpdateFrame containing new messages. + + Example: + ```python + transcript_processor = TranscriptProcessor() + + @transcript_processor.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + for msg in frame.messages: + print(f"{msg.role}: {msg.content}") + ``` + """ + + def __init__(self, **kwargs): + """Initialize the transcript processor. + + Args: + **kwargs: Additional arguments passed to FrameProcessor + """ + super().__init__(**kwargs) + self._processed_messages: List[TranscriptionMessage] = [] + self._register_event_handler("on_transcript_update") + + def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]: + """Extract conversation messages from standard format. + + Args: + messages: List of messages in standard format with structured content + + Returns: + List[TranscriptionMessage]: Normalized conversation messages + """ + result = [] + for msg in messages: + # Only process user and assistant messages + if msg["role"] not in ("user", "assistant"): + continue + + content = msg.get("content", []) + if isinstance(content, list): + # Extract text from structured content + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text_parts.append(part["text"]) + + if text_parts: + result.append( + TranscriptionMessage(role=msg["role"], content=" ".join(text_parts)) + ) + + return result + + def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]: + """Find messages in current that aren't in self._processed_messages. + + Args: + current: List of current messages + + Returns: + List[TranscriptionMessage]: New messages not yet processed + """ + if not self._processed_messages: + return current + + processed_len = len(self._processed_messages) + if len(current) <= processed_len: + return [] + + return current[processed_len:] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames, watching for OpenAILLMContextFrame. + + Args: + frame: The frame to process + direction: Frame processing direction + + Raises: + ErrorFrame: If message processing fails + """ + await super().process_frame(frame, direction) + + if isinstance(frame, OpenAILLMContextFrame): + try: + # Convert context messages to standard format + standard_messages = [] + for msg in frame.context.messages: + converted = frame.context.to_standard_messages(msg) + standard_messages.extend(converted) + + # Extract and process messages + current_messages = self._extract_messages(standard_messages) + new_messages = self._find_new_messages(current_messages) + + if new_messages: + # Update state and notify listeners + self._processed_messages.extend(new_messages) + update_frame = TranscriptionUpdateFrame(messages=new_messages) + await self._call_event_handler("on_transcript_update", update_frame) + await self.push_frame(update_frame) + + except Exception as e: + logger.error(f"Error processing transcript in {self}: {e}") + await self.push_error(ErrorFrame(str(e))) + + # Always push the original frame downstream + await self.push_frame(frame, direction)