From d6ef3d64ace855238e0ee60c5e5d92247b8a7448 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 30 Apr 2025 21:40:40 -0400 Subject: [PATCH] [WIP] AWS Nova Sonic service - fix context problems of double-counting LLM text, and mis-categorizing user text as LLM text --- .../services/aws_nova_sonic/context.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/pipecat/services/aws_nova_sonic/context.py b/src/pipecat/services/aws_nova_sonic/context.py index 5d9bafec5..4e2a4fcc1 100644 --- a/src/pipecat/services/aws_nova_sonic/context.py +++ b/src/pipecat/services/aws_nova_sonic/context.py @@ -16,6 +16,8 @@ from pipecat.frames.frames import ( FunctionCallResultFrame, LLMMessagesUpdateFrame, LLMSetToolsFrame, + LLMTextFrame, + TranscriptionFrame, ) from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection @@ -132,6 +134,23 @@ class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator): class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator): + # AWS Nova Sonic is a speech-to-speech model. + # It behaves like a combined STT + LLM + TTS service, emitting all of: + # - TranscriptionFrame (for user text) + # - LLMTextFrame (for assistant text) + # - TTSTextFrame (for assistant text) + # In a "standard" pipeline (with separate STT + LLM + TTS services): + # - The TranscriptionFrame is swallowed by the LLMUserContextAggregator + # - The LLMTextFrame is swallowed by the TTS service + # Meaning the LLMAssistantContextAggregator only receives the TTSTextFrames. It actually + # implicitly assumes it will receive only *non-duplicate* *assistant-related* text frames, and + # will misbehave otherwise (double-counting assistant text, or mis-categorizing user text as + # assistant text). + # So, let's override process_frame here to ignore TranscriptionFrames and LLMTextFrames. + async def process_frame(self, frame: Frame, direction: FrameDirection): + if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)): + await super().process_frame(frame, direction) + async def handle_function_call_result(self, frame: FunctionCallResultFrame): await super().handle_function_call_result(frame)