diff --git a/examples/foundational/22b-natural-conversation-proposal.py b/examples/foundational/22b-natural-conversation-proposal.py index 3e1d11c4b..e0a7f71ba 100644 --- a/examples/foundational/22b-natural-conversation-proposal.py +++ b/examples/foundational/22b-natural-conversation-proposal.py @@ -4,10 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import aiohttp +import asyncio import os import sys +import time from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.frames.frames import LLMMessagesFrame, TextFrame @@ -15,21 +16,29 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.gated_openai_llm_context import GatedOpenAILLMContextAggregator from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, - OpenAILLMContextFrame, ) -from pipecat.processors.filters.null_filter import NullFilter -from pipecat.processors.filters.wake_notifier_filter import WakeNotifierFilter -from pipecat.processors.user_idle_processor import UserIdleProcessor from pipecat.services.cartesia import CartesiaTTSService from pipecat.services.deepgram import DeepgramSTTService from pipecat.services.openai import OpenAILLMService from pipecat.sync.event_notifier import EventNotifier from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.frames.frames import Frame +from pipecat.processors.frame_processor import FrameProcessor, FrameDirection +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + StartFrame, + StartInterruptionFrame, + SystemFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame +from pipecat.sync.base_notifier import BaseNotifier +from pipecat.processors.filters.function_filter import FunctionFilter from runner import configure @@ -44,6 +53,144 @@ logger.remove(0) logger.add(sys.stderr, level="DEBUG") +classifier_statement = "Determine if the user's statement ends with a complete sentence or question. The user text is transcribed speech. It may contain multiple fragments concatentated together. Categorize the text as either complete with the user now expecting a response, or incomplete. Return 'YES' if text is likely complete and the user is expecting a response. Return 'NO' if the text seems to be a partial expression or unfinished thought." + + +class StatementJudgeContextFilter(FrameProcessor): + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + # We must not block system frames. + if isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + return + + # We only want to handle OpenAILLMContextFrames, and only want to push a simple + # messages frame that contains a system prompt and the most recent user messages, + # concatenated. + if isinstance(frame, OpenAILLMContextFrame): + logger.debug(f"Context Frame: {frame}") + # Take text content from the most recent user messages. + messages = frame.context.messages + user_text_messages = [] + last_assistant_message = None + for message in reversed(messages): + if message["role"] != "user": + if message["role"] == "assistant": + last_assistant_message = message + break + if isinstance(message["content"], str): + user_text_messages.append(message["content"]) + elif isinstance(message["content"], list): + for content in message["content"]: + if content["type"] == "text": + user_text_messages.append(content["text"]) + # If we have any user text content, push an LLMMessagesFrame + if user_text_messages: + logger.debug(f"User text messages: {user_text_messages}") + user_message = " ".join(reversed(user_text_messages)) + logger.debug(f"User message: {user_message}") + messages = [ + { + "role": "system", + "content": classifier_statement, + } + ] + if last_assistant_message: + messages.append(last_assistant_message) + messages.append({"role": "user", "content": user_message}) + await self.push_frame(LLMMessagesFrame(messages)) + + +class CompletenessCheck(FrameProcessor): + def __init__(self, complete_notifier: BaseNotifier, incomplete_notifier: BaseNotifier): + super().__init__() + self._complete_notifier = complete_notifier + self._incomplete_notifier = incomplete_notifier + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, TextFrame) and frame.text == "YES": + logger.debug("Completeness check YES") + await self.push_frame(UserStoppedSpeakingFrame()) + await self._complete_notifier.notify() + elif isinstance(frame, TextFrame) and frame.text == "NO": + logger.debug("Completeness check NO") + await self._incomplete_notifier.notify() + + +class OutputGate(FrameProcessor): + def __init__( + self, complete_notifier: BaseNotifier, incomplete_notifier: BaseNotifier, **kwargs + ): + super().__init__(**kwargs) + self._gate_open = False + self._frames_buffer = [] + self._complete_notifier = complete_notifier + self._incomplete_notifier = incomplete_notifier + + def close_gate(self): + self._gate_open = False + + def open_gate(self): + self._gate_open = True + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We must not block system frames. + if isinstance(frame, SystemFrame): + if isinstance(frame, StartFrame): + await self._start() + if isinstance(frame, (EndFrame, CancelFrame)): + await self._stop() + if isinstance(frame, StartInterruptionFrame): + self._frames_buffer = [] + self.close_gate() + await self.push_frame(frame, direction) + return + + # Ignore frames that are not following the direction of this gate. + if direction != FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + return + + if self._gate_open: + await self.push_frame(frame, direction) + return + + self._frames_buffer.append((frame, direction)) + + async def _start(self): + self._frames_buffer = [] + self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) + self._interrupt_task = self.get_event_loop().create_task(self._interrupt_task_handler()) + + async def _stop(self): + self._gate_task.cancel() + await self._gate_task + + async def _gate_task_handler(self): + while True: + try: + await self._complete_notifier.wait() + self.open_gate() + for frame, direction in self._frames_buffer: + await self.push_frame(frame, direction) + self._frames_buffer = [] + except asyncio.CancelledError: + break + + async def _interrupt_task_handler(self): + while True: + try: + await self._incomplete_notifier.wait() + await self.push_frame(StartInterruptionFrame(), FrameDirection.UPSTREAM) + self._frames_buffer = [] + + except asyncio.CancelledError: + break + + async def main(): async with aiohttp.ClientSession() as session: (room_url, _) = await configure(session) @@ -69,20 +216,9 @@ async def main(): # This is the LLM that will be used to detect if the user has finished a # statement. This doesn't really need to be an LLM, we could use NLP - # libraries for that, but it was easier as an example because we - # leverage the context aggregators. + # libraries for that, but we have the machinery to use an LLM, so we might as well! statement_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") - statement_messages = [ - { - "role": "system", - "content": "Determine if the user's statement is a complete sentence or question, ending in a natural pause or punctuation. Return 'YES' if it is complete and 'NO' if it seems to leave a thought unfinished.", - }, - ] - - statement_context = OpenAILLMContext(statement_messages) - statement_context_aggregator = statement_llm.create_context_aggregator(statement_context) - # This is the regular LLM. llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") @@ -105,79 +241,58 @@ async def main(): # This is a notifier that we use to synchronize the two LLMs. notifier = EventNotifier() + # rename/comment? + interrupt_notifier = EventNotifier() - # This a filter that will wake up the notifier if the given predicate - # (wake_check_filter) returns true. - completeness_check = WakeNotifierFilter( - notifier, types=(TextFrame,), filter=wake_check_filter + # This sends a UserStoppedSpeakingFrame and triggers the notifier event + completeness_check = CompletenessCheck( + complete_notifier=notifier, incomplete_notifier=interrupt_notifier ) - # This processor keeps the last context and will let it through once the - # notifier is woken up. - gated_context_aggregator = GatedOpenAILLMContextAggregator(notifier) + # # Notify if the user hasn't said anything. + # async def user_idle_notifier(frame): + # await notifier.notify() - # Notify if the user hasn't said anything. - async def user_idle_notifier(frame): - await notifier.notify() + # # Sometimes the LLM will fail detecting if a user has completed a + # # sentence, this will wake up the notifier if that happens. + # user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=10.0) - # Sometimes the LLM will fail detecting if a user has completed a - # sentence, this will wake up the notifier if that happens. - user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=3.0) + bot_output_gate = OutputGate( + complete_notifier=notifier, incomplete_notifier=interrupt_notifier + ) - class StatementJudgeContextFilter(FrameProcessor): - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, OpenAILLMContextFrame): - logger.debug(f"Context Frame: {frame}") - await self.push_frame(frame, direction) + async def block_user_stopped_speaking(frame): + return not isinstance(frame, UserStoppedSpeakingFrame) - class GatedTTSOutput(FrameProcessor): - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - await self.push_frame(frame, direction) + async def pass_only_llm_trigger_frames(frame): + return isinstance(frame, OpenAILLMContextFrame) or isinstance(frame, LLMMessagesFrame) - # The ParallePipeline input are the user transcripts. We have two - # contexts. The first one will be used to determine if the user finished - # a statement and if so the notifier will be woken up. The second - # context is simply the regular context but it's gated waiting for the - # notifier to be woken up. pipeline = Pipeline( [ transport.input(), stt, - ParallelPipeline( - [ - statement_context_aggregator.user(), - statement_llm, - completeness_check, - NullFilter(), - ], - [context_aggregator.user(), gated_context_aggregator, llm], - ), - user_idle, - tts, # TTS - transport.output(), - context_aggregator.assistant(), - ] - ) - - pipeline_x = Pipeline( - [ - transport.input(), - stt, - user_idle, + # user_idle, context_aggregator.user(), ParallelPipeline( [ + # Pass everything except UserStoppedSpeaking to the elements after + # this ParallelPipeline + FunctionFilter(filter=block_user_stopped_speaking), + ], + [ + # Ignore everything except an OpenAILLMContextFrame. Pass a specially constructed + # LLMMessagesFrame to the statement classifier LLM. The only frame this + # sub-pipeline will output is a UserStoppedSpeakingFrame. StatementJudgeContextFilter(), statement_llm, completeness_check, - NullFilter(), ], [ + # Block everything except OpenAILLMContextFrame and LLMMessagesFrame + FunctionFilter(filter=pass_only_llm_trigger_frames), llm, tts, - GatedTTSOutput(), + bot_output_gate, # Buffer all llm/tts output until notified. ], ), transport.output(), @@ -186,8 +301,7 @@ async def main(): ) task = PipelineTask( - # pipeline, - pipeline_x, + pipeline, PipelineParams( allow_interruptions=True, enable_metrics=True, @@ -203,8 +317,23 @@ async def main(): messages.append({"role": "system", "content": "Please introduce yourself to the user."}) await task.queue_frames([LLMMessagesFrame(messages)]) - runner = PipelineRunner() + @transport.event_handler("on_app_message") + async def on_app_message(transport, message, sender): + logger.debug(f"Received app message: {message} - {sender}") + if "message" not in message: + return + await task.queue_frames( + [ + UserStartedSpeakingFrame(), + TranscriptionFrame( + user_id=sender, timestamp=time.time(), text=message["message"] + ), + UserStoppedSpeakingFrame(), + ] + ) + + runner = PipelineRunner() await runner.run(task)