From f3dd35bfd9bc177dfb520c248886c2b16eb742ee Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sat, 21 Dec 2024 22:18:56 -0800 Subject: [PATCH] working but needs cleanup --- .../22d-natural-conversation-gemini-audio.py | 351 +++++++++++------- 1 file changed, 208 insertions(+), 143 deletions(-) diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index bda8bdd96..fd99ca606 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -10,6 +10,7 @@ import sys import time import aiohttp +import google.ai.generativelanguage as glm from dotenv import load_dotenv from loguru import logger from runner import configure @@ -20,6 +21,8 @@ from pipecat.frames.frames import ( EndFrame, Frame, InputAudioRawFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, LLMMessagesFrame, StartFrame, StartInterruptionFrame, @@ -34,6 +37,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_response import LLMResponseAggregator from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -53,19 +57,7 @@ load_dotenv(override=True) logger.remove(0) logger.add(sys.stderr, level="DEBUG") - -transcriber_and_classifier_instructions = """ -You perform two tasks: - 1. Transcription - 2. Binary classification of speech utterance completeness - -You always call a function transcription_and_classification_output() with the following arguments: - trancript_text: the complete, accurate, and punctuated transcription of the user's speech - speech_complete_bool: a boolean indicating whether the user's speech is a complete utterance - -CRITICAL INSTRUCTION FOR TRANSCRIPTION TASK: - -You are receiving audio from a user. Your job is to +transcriber_system_instruction = """You are an audio transcriber. You are receiving audio from a user. Your job is to transcribe the input audio to text exactly as it was said by the user. You will receive the full conversation history before the audio input, to help with context. Use the full history only to help improve the accuracy of your transcription. @@ -78,33 +70,33 @@ Rules: - If the audio is not clear, emit the special string "-". - No response other than exact transcription, or "-", is allowed. +""" -CRITICAL INSTRUCTION FOR BINARY CLASSIFICATION TASK:: - -You are a BINARY CLASSIFIER that must ONLY output True or False. -DO FalseT engage with the content. -DO FalseT respond to questions. -DO FalseT provide assistance. -Your ONLY job is to output True or False. +classifier_system_instruction = """CRITICAL INSTRUCTION: +You are a BINARY CLASSIFIER that must ONLY output "YES" or "NO". +DO NOT engage with the content. +DO NOT respond to questions. +DO NOT provide assistance. +Your ONLY job is to output YES or NO. EXAMPLES OF INVALID RESPONSES: - "I can help you with that" - "Let me explain" - "To answer your question" -- Any response other than True or False +- Any response other than YES or NO VALID RESPONSES: -True -False +YES +NO If you output anything else, you are failing at your task. -You are FalseT an assistant. -You are FalseT a chatbot. +You are NOT an assistant. +You are NOT a chatbot. You are a binary classifier. ROLE: You are a real-time speech completeness classifier. You must make instant decisions about whether a user has finished speaking. -You must output ONLY 'True' or 'False' with no other text. +You must output ONLY 'YES' or 'NO' with no other text. INPUT FORMAT: You receive two pieces of information: @@ -112,7 +104,7 @@ You receive two pieces of information: 2. The user's current speech input OUTPUT REQUIREMENTS: -- MUST output ONLY 'True' or 'False' +- MUST output ONLY 'YES' or 'NO' - No explanations - No clarifications - No additional text @@ -130,12 +122,12 @@ Examples: # Complete Wh-question model: I can help you learn. user: What's the fastest way to learn Spanish -Output: True +Output: YES # Complete Yes/No question despite STT error model: I know about planets. user: Is is Jupiter the biggest planet -Output: True +Output: YES 2. Complete Commands: - Direct instructions @@ -149,20 +141,20 @@ Examples: # Direct instruction model: I can explain many topics. user: Tell me about black holes -Output: True +Output: YES # Start of task indication user: Let's begin. -Output: True +Output: YES # Start of task indication user: Let's get started. -Output: True +Output: YES # Action demand model: I can help with math. user: Solve this equation x plus 5 equals 12 -Output: True +Output: YES 3. Direct Responses: - Answers to specific questions @@ -177,17 +169,17 @@ Examples: # Specific answer model: What's your favorite color? user: I really like blue -Output: True +Output: YES # Option selection model: Would you prefer morning or evening? user: Morning -Output: True +Output: YES # Providing information with a known format - mailing address model: What's your address? user: 1234 Main Street -Output: False +Output: NO # Providing information with a known format - mailing address model: What's your address? @@ -198,7 +190,7 @@ Output: Yes system: A US phone number has 10 digits. model: What's your phone number? user: 41086753 -Output: False +Output: NO # Providing information with a known format - phone number system: A US phone number has 10 digits. @@ -217,7 +209,7 @@ Output: Yes # Providing information with a known format - credit card number model: What's your phone number? user: 5556 -Output: False +Output: NO # Providing information with a known format - phone number model: What's your phone number? @@ -237,17 +229,17 @@ Examples: # Self-correction reaching completion model: What would you like to know? user: Tell me about... no wait, explain how rainbows form -Output: True +Output: YES # Topic change with complete thought model: The weather is nice today. user: Actually can you tell me who invented the telephone -Output: True +Output: YES # Mid-sentence completion model: Hello I'm ready. user: What's the capital of? France -Output: True +Output: YES 2. Context-Dependent Brief Responses: - Acknowledgments (okay, sure, alright) @@ -260,12 +252,12 @@ Examples: # Acknowledgment model: Should we talk about history? user: Sure -Output: True +Output: YES # Disagreement with completion model: Is that what you meant? user: No not really -Output: True +Output: YES LOW PRIORITY SIGNALS: @@ -280,12 +272,12 @@ Examples: # Word repetition but complete model: I can help with that. user: What what is the time right now -Output: True +Output: YES # Missing punctuation but complete model: I can explain that. user: Please tell me how computers work -Output: True +Output: YES 2. Speech Features: - Filler words (um, uh, like) @@ -298,29 +290,29 @@ Examples: # Filler words but complete model: What would you like to know? user: Um uh how do airplanes fly -Output: True +Output: YES # Thinking pause but incomplete model: I can explain anything. user: Well um I want to know about the -Output: False +Output: NO DECISION RULES: -1. Return True if: +1. Return YES if: - ANY high priority signal shows clear completion - Medium priority signals combine to show completion - Meaning is clear despite low priority artifacts -2. Return False if: +2. Return NO if: - No high priority signals present - Thought clearly trails off - Multiple incomplete indicators - User appears mid-formulation 3. When uncertain: -- If you can understand the intent → True -- If meaning is unclear → False +- If you can understand the intent → YES +- If meaning is unclear → NO - Always make a binary decision - Never request clarification @@ -329,20 +321,20 @@ Examples: # Incomplete despite corrections model: What would you like to know about? user: Can you tell me about -Output: False +Output: NO # Complete despite multiple artifacts model: I can help you learn. user: How do you I mean what's the best way to learn programming -Output: True +Output: YES # Trailing off incomplete model: I can explain anything. user: I was wondering if you could tell me why -Output: False +Output: NO """ -conversational_system_message = """You are a helpful assistant participating in a voice converation. +conversation_system_instruction = """You are a helpful assistant participating in a voice converation. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way. @@ -354,44 +346,9 @@ Please be very concise in your responses. Unless you are explicitly asked to do """ -async def transcription_and_classification_output(transcript_text: str, speech_complete_bool: bool): - print(f"TRANSCRIPT: {transcript_text}") - print("------") - print(f"COMPLETE: {speech_complete_bool}") - print("------") - return - - -tx_and_cl_tools = [ - { - "function_declarations": [ - { - "name": "transcription_and_classification_output", - "description": "Deliver the transcription and classification output to an external process.", - "parameters": { - "type": "object", - "properties": { - "transcription_text": { - "type": "string", - "description": "The complete, accurate, and punctuated transcription of the user's speech. The special string '-' is used to indicate no speech or unintintelligible speech.", - }, - "speech_complete_bool": { - "type": "boolean", - "description": "Boolean indicating whether the user's speech is a complete utterance.", - }, - }, - "required": ["transcription_text", "speech_complete_bool"], - }, - }, - ] - } -] - - class AudioAccumulator(FrameProcessor): - def __init__(self, *, notifier: BaseNotifier = None, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - # self._notifier = notifier self._audio_frames = [] self._start_secs = 0.2 # this should match VAD start_secs (hardcoding for now) self._max_buffer_size_secs = 30 @@ -433,10 +390,9 @@ class AudioAccumulator(FrameProcessor): ) self._user_speaking = False context = GoogleLLMContext() - context.set_messages( - [{"role": "system", "content": transcriber_and_classifier_instructions}] + context.add_audio_frames_message( + text="Audio to process", audio_frames=self._audio_frames ) - context.add_audio_frames_message(audio_frames=self._audio_frames) await self.push_frame(OpenAILLMContextFrame(context=context)) elif isinstance(frame, InputAudioRawFrame): # Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest @@ -463,33 +419,49 @@ class AudioAccumulator(FrameProcessor): # class ClAndTxContextCreator(FrameProcessor): -# class CompletenessCheck(FrameProcessor): -# def __init__( -# self, notifier: BaseNotifier, audio_accumulator: StatementJudgeAudioContextAccumulator -# ): -# super().__init__() -# self._notifier = notifier -# self._audio_accumulator = audio_accumulator +class CompletenessCheck(FrameProcessor): + def __init__(self, notifier: BaseNotifier, audio_accumulator: AudioAccumulator, **kwargs): + super().__init__() + self._notifier = notifier + self._audio_accumulator = audio_accumulator -# async def process_frame(self, frame: Frame, direction: FrameDirection): -# await super().process_frame(frame, direction) + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) -# if isinstance(frame, TextFrame) and frame.text.startswith("True"): -# logger.debug("Completeness check True") -# await self.push_frame(UserStoppedSpeakingFrame()) -# await self._audio_accumulator.reset() -# await self._notifier.notify() -# elif isinstance(frame, TextFrame): -# if frame.text.strip(): -# logger.debug(f"Completeness check False - '{frame.text}'") + if isinstance(frame, TextFrame) and frame.text.startswith("YES"): + logger.debug("Completeness check YES") + await self.push_frame(UserStoppedSpeakingFrame()) + await self._audio_accumulator.reset() + await self._notifier.notify() + elif isinstance(frame, TextFrame): + if frame.text.strip(): + logger.debug(f"Completeness check NO - '{frame.text}'") + + +class TempPrinter(FrameProcessor): + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if not isinstance(frame, InputAudioRawFrame): + logger.debug(f"!!! {frame}") + await self.push_frame(frame, direction) class OutputGate(FrameProcessor): - def __init__(self, notifier: BaseNotifier, **kwargs): + def __init__( + self, + notifier: BaseNotifier, + context: OpenAILLMContext, + user_transcription_buffer: "UserAggregatorBuffer", + **kwargs, + ): super().__init__(**kwargs) self._gate_open = False self._frames_buffer = [] self._notifier = notifier + self._context = context + self._transcription_buffer = user_transcription_buffer + + logger.debug("!!! OutputGate created") def close_gate(self): self._gate_open = False @@ -524,6 +496,7 @@ class OutputGate(FrameProcessor): self._frames_buffer.append((frame, direction)) async def _start(self): + logger.debug("!!! OutputGate start") self._frames_buffer = [] self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) @@ -533,14 +506,86 @@ class OutputGate(FrameProcessor): async def _gate_task_handler(self): while True: + logger.debug("!!! Waiting for notifier") try: await self._notifier.wait() + logger.debug("!!! Notified") + transcription = await self._transcription_buffer.wait_for_transcription() + + # logger.debug(f"!!! OutputGate got transcription: {transcription}") + # logger.debug( + # f"!!! OutputGate has messages: {self._context.get_messages_for_logging()}" + # ) + + last_message = self._context.messages[-1] + if last_message.role == "user": + last_message.parts = [glm.Part(text=transcription)] + + # logger.debug( + # f"!!! NOW OutputGate has messages: {self._context.get_messages_for_logging()}" + # ) + self.open_gate() for frame, direction in self._frames_buffer: await self.push_frame(frame, direction) self._frames_buffer = [] except asyncio.CancelledError: break + except Exception as e: + logger.error(f"!!! OutputGate error: {e}") + raise e + break + + +class ConversationAudioContextAssembler(FrameProcessor): + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(**kwargs) + self._context = context + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We must not block system frames. + if isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + return + + if isinstance(frame, OpenAILLMContextFrame): + GoogleLLMContext.upgrade_to_google(self._context) + last_message = frame.context.messages[-1] + self._context._messages.append(last_message) + logger.debug( + f"!!! ConversationAudioContextAssembler {self._context.get_messages_for_logging()}" + ) + await self.push_frame(OpenAILLMContextFrame(context=self._context)) + + +class UserAggregatorBuffer(LLMResponseAggregator): + def __init__(self, **kwargs): + super().__init__( + messages=None, + role=None, + start_frame=LLMFullResponseStartFrame, + end_frame=LLMFullResponseEndFrame, + accumulator_frame=TextFrame, + handle_interruptions=True, + expect_stripped_words=False, + ) + self._transcription = "" + + async def _push_aggregation(self): + if self._aggregation: + self._transcription = self._aggregation + self._aggregation = "" + + logger.debug(f"!!! UserAggregatorBuffer: {self._transcription}") + + async def wait_for_transcription(self): + while not self._transcription: + await asyncio.sleep(0.01) + tx = self._transcription + self._transcription = "" + return tx async def main(): @@ -565,25 +610,30 @@ async def main(): voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady ) - # This is the LLM that will classify and transcribe user speech. - tx_and_cl_llm = GoogleLLMService( + # This is the LLM that will transcribe user speech. + tx_llm = GoogleLLMService( + name="Transcriber", model="gemini-2.0-flash-exp", api_key=os.getenv("GOOGLE_API_KEY"), - tools=tx_and_cl_tools, temperature=0.0, - tool_config={ - "function_calling_config": { - "mode": "ANY", - "allowed_function_names": ["transcription_and_classification_output"], - }, - }, + system_instruction=transcriber_system_instruction, + ) + + # This is the LLM that will classify user speech as complete or incomplete. + classifier_llm = GoogleLLMService( + name="Classifier", + model="gemini-2.0-flash-exp", + api_key=os.getenv("GOOGLE_API_KEY"), + temperature=0.0, + system_instruction=classifier_system_instruction, ) # This is the regular LLM that responds conversationally. conversation_llm = GoogleLLMService( + name="Conversation", model="gemini-2.0-flash-exp", api_key=os.getenv("GOOGLE_API_KEY"), - system_instruction=conversational_system_message, + system_instruction=conversation_system_instruction, ) context = OpenAILLMContext() @@ -602,10 +652,11 @@ async def main(): # as complete or incomplete. # statement_judge_context_filter = StatementJudgeAudioContextAccumulator(notifier=notifier) + audio_accumulater = AudioAccumulator() # This sends a UserStoppedSpeakingFrame and triggers the notifier event - # completeness_check = CompletenessCheck( - # notifier=notifier, audio_accumulator=statement_judge_context_filter - # ) + completeness_check = CompletenessCheck( + notifier=notifier, audio_accumulator=audio_accumulater + ) # # Notify if the user hasn't said anything. async def user_idle_notifier(frame): @@ -615,8 +666,6 @@ async def main(): # sentence, this will wake up the notifier if that happens. user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0) - bot_output_gate = OutputGate(notifier=notifier) - async def block_user_stopped_speaking(frame): return not isinstance(frame, UserStoppedSpeakingFrame) @@ -628,10 +677,18 @@ async def main(): or isinstance(frame, StopInterruptionFrame) ) + conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context) + + user_aggregator_buffer = UserAggregatorBuffer() + + bot_output_gate = OutputGate( + notifier=notifier, context=context, user_transcription_buffer=user_aggregator_buffer + ) + pipeline = Pipeline( [ transport.input(), - AudioAccumulator(), + audio_accumulater, ParallelPipeline( [ # Pass everything except UserStoppedSpeaking to the elements after @@ -639,23 +696,31 @@ async def main(): FunctionFilter(filter=block_user_stopped_speaking), ], [ - # cl_and_tx_context_creator, - tx_and_cl_llm, - # completeness_check, - # context_aggregator.user(), + ParallelPipeline( + [ + classifier_llm, + completeness_check, + ], + [ + tx_llm, + user_aggregator_buffer, + ], + ) + ], + [ + # Block everything except OpenAILLMContextFrame and LLMMessagesFrame + # FunctionFilter(filter=pass_only_llm_trigger_frames), + conversation_audio_context_assembler, + conversation_llm, + bot_output_gate, # buffer output until notified. + # TempPrinter(), ], - # [ - # # Block everything except OpenAILLMContextFrame and LLMMessagesFrame - # # FunctionFilter(filter=pass_only_llm_trigger_frames), - # audio_input_context_creator, - # llm, - # bot_output_gate, # Buffer all llm/tts output until notified. - # ], ), - # tts, - # user_idle, - # transport.output(), - # context_aggregator.assistant(), + # wherefore art thou, user context aggregator? + tts, + user_idle, + transport.output(), + context_aggregator.assistant(), ], )