From 1ffc71657504d43effd36bc1a4e9685792ad99a7 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Fri, 22 May 2026 21:30:54 +0800 Subject: [PATCH] fix text input no context update --- engine/pipeline.py | 18 ++++++++---------- engine/text_stream.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/engine/pipeline.py b/engine/pipeline.py index 1dcb6b7..7eb2ae5 100644 --- a/engine/pipeline.py +++ b/engine/pipeline.py @@ -14,7 +14,6 @@ from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_response_universal import ( - AssistantTurnStoppedMessage, LLMContextAggregatorPair, LLMUserAggregatorParams, UserTurnStoppedMessage, @@ -34,7 +33,7 @@ from .config import EngineConfig from .product_protocol import ProductWebsocketSerializer from .services import create_llm_service, create_stt_service, create_tts_service from .text_input import ProductTextInputProcessor -from .text_stream import ProductTextStreamProcessor +from .text_stream import ProductAssistantTurnStoppedMessage, ProductTextStreamProcessor from .transcript_stream import ProductTranscriptStreamProcessor from .turn_start import InterruptionGateUserTurnStartStrategy @@ -118,13 +117,14 @@ async def run_pipeline_with_serializer( ), ], ) - user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + user_aggregator, _ = LLMContextAggregatorPair( context, user_params=LLMUserAggregatorParams( vad_analyzer=SileroVADAnalyzer(params=vad_params), user_turn_strategies=user_turn_strategies, ), ) + text_stream = ProductTextStreamProcessor(context) pipeline = Pipeline( [ @@ -134,10 +134,9 @@ async def run_pipeline_with_serializer( ProductTranscriptStreamProcessor(), user_aggregator, llm, - ProductTextStreamProcessor(), + text_stream, tts, transport.output(), - assistant_aggregator, ] ) @@ -188,11 +187,10 @@ async def run_pipeline_with_serializer( ) ) - # NOTE: assistant turn started/final events are emitted by - # ProductTextStreamProcessor, upstream of TTS, so text streams to the - # client ahead of audio. This logger is kept for server-side visibility. - @assistant_aggregator.event_handler("on_assistant_turn_stopped") - async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage): + @text_stream.event_handler("on_assistant_turn_stopped") + async def on_assistant_turn_stopped( + _aggregator, message: ProductAssistantTurnStoppedMessage + ): logger.info(f"Assistant: {message.content}") runner = PipelineRunner(handle_sigint=False) diff --git a/engine/text_stream.py b/engine/text_stream.py index e2e997e..d3f8882 100644 --- a/engine/text_stream.py +++ b/engine/text_stream.py @@ -1,5 +1,9 @@ from __future__ import annotations +from dataclasses import dataclass + +from loguru import logger + from pipecat.frames.frames import ( Frame, InterruptionFrame, @@ -9,7 +13,16 @@ from pipecat.frames.frames import ( OutputTransportMessageUrgentFrame, TTSSpeakFrame, ) +from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.time import time_now_iso8601 + + +@dataclass +class ProductAssistantTurnStoppedMessage: + content: str + interrupted: bool + timestamp: str class ProductTextStreamProcessor(FrameProcessor): @@ -29,10 +42,13 @@ class ProductTextStreamProcessor(FrameProcessor): started/delta/final sequence for its fixed text. """ - def __init__(self) -> None: + def __init__(self, context: LLMContext | None = None) -> None: super().__init__() + self._context = context self._aggregation: list[str] = [] self._turn_active = False + self._turn_start_timestamp = "" + self._register_event_handler("on_assistant_turn_stopped") async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: await super().process_frame(frame, direction) @@ -62,6 +78,7 @@ class ProductTextStreamProcessor(FrameProcessor): return self._turn_active = True self._aggregation = [] + self._turn_start_timestamp = time_now_iso8601() await self._emit("response.text.started") async def _delta(self, text: str) -> None: @@ -78,11 +95,26 @@ class ProductTextStreamProcessor(FrameProcessor): full_text = "".join(self._aggregation) self._turn_active = False self._aggregation = [] + if self._context and full_text: + self._context.add_message({"role": "assistant", "content": full_text}) + logger.info( + "Assistant committed to LLM context before TTS: " + f"{full_text[:120]}" + ) await self._emit( "response.text.final", text=full_text, interrupted=interrupted, ) + await self._call_event_handler( + "on_assistant_turn_stopped", + ProductAssistantTurnStoppedMessage( + content=full_text, + interrupted=interrupted, + timestamp=self._turn_start_timestamp or time_now_iso8601(), + ), + ) + self._turn_start_timestamp = "" async def _emit(self, event_type: str, **payload: object) -> None: await self.push_frame(