From 632bae7eee590027acd94473d5d3cb0dda6ec434 Mon Sep 17 00:00:00 2001 From: James Hush Date: Wed, 27 Nov 2024 12:21:45 +0800 Subject: [PATCH] Interrupted? --- examples/foundational/race_bot.py | 64 ++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/examples/foundational/race_bot.py b/examples/foundational/race_bot.py index 3b24cc7d7..62c26abd3 100644 --- a/examples/foundational/race_bot.py +++ b/examples/foundational/race_bot.py @@ -14,6 +14,9 @@ from loguru import logger from runner import configure from pipecat.frames.frames import ( + BotSpeakingFrame, + EndFrame, + Frame, StartInterruptionFrame, StopInterruptionFrame, TranscriptionFrame, @@ -24,13 +27,29 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.services.cartesia import CartesiaTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport +logger.remove(0) logger.add(sys.stderr, level="DEBUG") +class DebugProcessor(FrameProcessor): + def __init__(self, name, **kwargs): + self._name = name + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if not ( + isinstance(frame, BotSpeakingFrame) + ): + logger.debug(f"--- {self._name}: {frame} {direction}") + await self.push_frame(frame, direction) + + async def main(): async with aiohttp.ClientSession() as session: (room_url, _) = await configure(session) @@ -53,6 +72,8 @@ async def main(): }, ] + dp = DebugProcessor("dp") + context = OpenAILLMContext(messages) context_aggregator = llm.create_context_aggregator(context) @@ -61,6 +82,7 @@ async def main(): task = PipelineTask( Pipeline( [ + dp, context_aggregator.user(), llm, tts, @@ -74,23 +96,39 @@ async def main(): # participant joins. @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): - # Create frames for 3 seconds - start_time = time.time() - while time.time() - start_time < 300: - timestamp = time.time() - frames = [ + participant_id = participant.get("info", {}).get("participantId", "") + + await task.queue_frames( + [ UserStartedSpeakingFrame(), - TranscriptionFrame("Tell a joke about dogs.", "user_id", timestamp), + TranscriptionFrame("Tell a joke about dogs.", participant_id, time.time()), UserStoppedSpeakingFrame(), ] - await task.queue_frames(frames) + ) + # await asyncio.sleep(5) # Small delay between frame sets + + # Create frames for 60 seconds + start_time = time.time() + while time.time() - start_time < 30: + elapsed_time = round(time.time() - start_time) + logger.info(f"Running for {elapsed_time} seconds") await asyncio.sleep(5) # Small delay between frame sets - next_frames = [ - StartInterruptionFrame(), - TranscriptionFrame("Tell a joke about cats.", "user_id", timestamp), - StopInterruptionFrame(), - ] - await task.queue_frames(next_frames) + await task.queue_frames( + [ + StartInterruptionFrame(), + TranscriptionFrame("Tell a joke about cats.", participant_id, time.time()), + StopInterruptionFrame(), + ] + ) + await asyncio.sleep(5) # Small delay between frame sets + await task.queue_frames( + [ + StartInterruptionFrame(), + TranscriptionFrame("Tell a joke about dogs.", participant_id, time.time()), + StopInterruptionFrame(), + ] + ) + await task.queue_frame(EndFrame()) await runner.run(task)