fix text input no context update

This commit is contained in:
Xin Wang
2026-05-22 21:30:54 +08:00
parent 7267c06552
commit 1ffc716575
2 changed files with 41 additions and 11 deletions

View File

@@ -14,7 +14,6 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import ( from pipecat.processors.aggregators.llm_response_universal import (
AssistantTurnStoppedMessage,
LLMContextAggregatorPair, LLMContextAggregatorPair,
LLMUserAggregatorParams, LLMUserAggregatorParams,
UserTurnStoppedMessage, UserTurnStoppedMessage,
@@ -34,7 +33,7 @@ from .config import EngineConfig
from .product_protocol import ProductWebsocketSerializer from .product_protocol import ProductWebsocketSerializer
from .services import create_llm_service, create_stt_service, create_tts_service from .services import create_llm_service, create_stt_service, create_tts_service
from .text_input import ProductTextInputProcessor from .text_input import ProductTextInputProcessor
from .text_stream import ProductTextStreamProcessor from .text_stream import ProductAssistantTurnStoppedMessage, ProductTextStreamProcessor
from .transcript_stream import ProductTranscriptStreamProcessor from .transcript_stream import ProductTranscriptStreamProcessor
from .turn_start import InterruptionGateUserTurnStartStrategy from .turn_start import InterruptionGateUserTurnStartStrategy
@@ -118,13 +117,14 @@ async def run_pipeline_with_serializer(
), ),
], ],
) )
user_aggregator, assistant_aggregator = LLMContextAggregatorPair( user_aggregator, _ = LLMContextAggregatorPair(
context, context,
user_params=LLMUserAggregatorParams( user_params=LLMUserAggregatorParams(
vad_analyzer=SileroVADAnalyzer(params=vad_params), vad_analyzer=SileroVADAnalyzer(params=vad_params),
user_turn_strategies=user_turn_strategies, user_turn_strategies=user_turn_strategies,
), ),
) )
text_stream = ProductTextStreamProcessor(context)
pipeline = Pipeline( pipeline = Pipeline(
[ [
@@ -134,10 +134,9 @@ async def run_pipeline_with_serializer(
ProductTranscriptStreamProcessor(), ProductTranscriptStreamProcessor(),
user_aggregator, user_aggregator,
llm, llm,
ProductTextStreamProcessor(), text_stream,
tts, tts,
transport.output(), transport.output(),
assistant_aggregator,
] ]
) )
@@ -188,11 +187,10 @@ async def run_pipeline_with_serializer(
) )
) )
# NOTE: assistant turn started/final events are emitted by @text_stream.event_handler("on_assistant_turn_stopped")
# ProductTextStreamProcessor, upstream of TTS, so text streams to the async def on_assistant_turn_stopped(
# client ahead of audio. This logger is kept for server-side visibility. _aggregator, message: ProductAssistantTurnStoppedMessage
@assistant_aggregator.event_handler("on_assistant_turn_stopped") ):
async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
logger.info(f"Assistant: {message.content}") logger.info(f"Assistant: {message.content}")
runner = PipelineRunner(handle_sigint=False) runner = PipelineRunner(handle_sigint=False)

View File

@@ -1,5 +1,9 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from loguru import logger
from pipecat.frames.frames import ( from pipecat.frames.frames import (
Frame, Frame,
InterruptionFrame, InterruptionFrame,
@@ -9,7 +13,16 @@ from pipecat.frames.frames import (
OutputTransportMessageUrgentFrame, OutputTransportMessageUrgentFrame,
TTSSpeakFrame, TTSSpeakFrame,
) )
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.time import time_now_iso8601
@dataclass
class ProductAssistantTurnStoppedMessage:
content: str
interrupted: bool
timestamp: str
class ProductTextStreamProcessor(FrameProcessor): class ProductTextStreamProcessor(FrameProcessor):
@@ -29,10 +42,13 @@ class ProductTextStreamProcessor(FrameProcessor):
started/delta/final sequence for its fixed text. started/delta/final sequence for its fixed text.
""" """
def __init__(self) -> None: def __init__(self, context: LLMContext | None = None) -> None:
super().__init__() super().__init__()
self._context = context
self._aggregation: list[str] = [] self._aggregation: list[str] = []
self._turn_active = False self._turn_active = False
self._turn_start_timestamp = ""
self._register_event_handler("on_assistant_turn_stopped")
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
await super().process_frame(frame, direction) await super().process_frame(frame, direction)
@@ -62,6 +78,7 @@ class ProductTextStreamProcessor(FrameProcessor):
return return
self._turn_active = True self._turn_active = True
self._aggregation = [] self._aggregation = []
self._turn_start_timestamp = time_now_iso8601()
await self._emit("response.text.started") await self._emit("response.text.started")
async def _delta(self, text: str) -> None: async def _delta(self, text: str) -> None:
@@ -78,11 +95,26 @@ class ProductTextStreamProcessor(FrameProcessor):
full_text = "".join(self._aggregation) full_text = "".join(self._aggregation)
self._turn_active = False self._turn_active = False
self._aggregation = [] self._aggregation = []
if self._context and full_text:
self._context.add_message({"role": "assistant", "content": full_text})
logger.info(
"Assistant committed to LLM context before TTS: "
f"{full_text[:120]}"
)
await self._emit( await self._emit(
"response.text.final", "response.text.final",
text=full_text, text=full_text,
interrupted=interrupted, interrupted=interrupted,
) )
await self._call_event_handler(
"on_assistant_turn_stopped",
ProductAssistantTurnStoppedMessage(
content=full_text,
interrupted=interrupted,
timestamp=self._turn_start_timestamp or time_now_iso8601(),
),
)
self._turn_start_timestamp = ""
async def _emit(self, event_type: str, **payload: object) -> None: async def _emit(self, event_type: str, **payload: object) -> None:
await self.push_frame( await self.push_frame(