From b749d2e075c650d508d969639a48091312e9c600 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Sun, 14 Jun 2026 22:12:56 +0800 Subject: [PATCH] Enhance text input processing and LLM interaction in the backend - Refactor TextInputProcessor to handle immediate and silent text inputs, improving user experience during voice interactions. - Introduce PassthroughLLMAssistantAggregator to manage LLM responses while preserving context for downstream TTS processing. - Update event handling for text input and client readiness, ensuring timely updates to the conversation context. - Modify run_pipeline to integrate new aggregators and streamline message handling, enhancing overall pipeline efficiency. - Improve message ordering in useVoicePreview to ensure accurate display of chat messages based on timestamps. --- backend/services/pipecat/pipeline.py | 106 ++++++++++++++++++++---- frontend/src/hooks/use-voice-preview.ts | 18 +++- 2 files changed, 107 insertions(+), 17 deletions(-) diff --git a/backend/services/pipecat/pipeline.py b/backend/services/pipecat/pipeline.py index 9a7a6f1..a9648d3 100644 --- a/backend/services/pipecat/pipeline.py +++ b/backend/services/pipecat/pipeline.py @@ -13,17 +13,21 @@ from services.pipecat.service_factory import create_services from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.frames.frames import ( EndFrame, - InputTextRawFrame, InputTransportMessageFrame, + InterruptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, LLMMessagesAppendFrame, OutputTransportMessageUrgentFrame, + TextFrame, TTSSpeakFrame, ) from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.worker import PipelineParams, PipelineWorker from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_response_universal import ( - LLMContextAggregatorPair, + LLMAssistantAggregator, + LLMUserAggregator, LLMUserAggregatorParams, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @@ -57,11 +61,21 @@ def _text_input(message) -> tuple[str, bool] | None: class TextInputProcessor(FrameProcessor): - """把 transport 文字消息转换成级联与实时 LLM 都能消费的帧。""" + """把 transport 文字消息转换成 LLM 可消费的帧。 + + run_immediately(默认/打断):先通过 on_text_input 事件把用户文字交给 + run_pipeline 登记,再用 broadcast_interruption() 打断当前播报。新的 LLM + 回复由 assistant aggregator 确认处理完 interruption 后触发。 + run_immediately=False(RTVI send-text 静默追加):仅把文字写进上下文, + 不打断、不触发推理。 + """ def __init__(self): super().__init__() + # 立即触发的文字(含打断语义)走 on_text_input;静默追加另走一条事件 self._register_event_handler("on_text_input") + self._register_event_handler("on_text_append") + self._register_event_handler("on_client_ready") async def process_frame(self, frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -70,6 +84,10 @@ class TextInputProcessor(FrameProcessor): await self.push_frame(frame, direction) return + if isinstance(frame.message, dict) and frame.message.get("type") == "client-ready": + await self._call_event_handler("on_client_ready") + return + parsed = _text_input(frame.message) if not parsed: await self.push_frame(frame, direction) @@ -77,17 +95,33 @@ class TextInputProcessor(FrameProcessor): text, run_immediately = parsed if run_immediately: + # 先登记文字再打断。下一轮 LLM 由 assistant aggregator 在真正处理完 + # InterruptionFrame 后触发,避免新回复被这次 interruption 一起取消。 + await self._call_event_handler("on_text_input", text) await self.broadcast_interruption() + else: + await self._call_event_handler("on_text_append", text) - await self.push_frame( - LLMMessagesAppendFrame( - messages=[{"role": "user", "content": text}], - run_llm=run_immediately, - ) - ) - if run_immediately: - await self.push_frame(InputTextRawFrame(text=text)) - await self._call_event_handler("on_text_input", text) + +class PassthroughLLMAssistantAggregator(LLMAssistantAggregator): + """聚合 LLM 回复进上下文,同时继续把回复帧交给下游 TTS。""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._register_event_handler("on_interruption_processed") + + async def process_frame(self, frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # LLMAssistantAggregator 默认会消费这些帧。放在 TTS 前用于中断时保存 + # 已生成前缀时,必须显式透传,否则 TTS 收不到任何 LLM 回复。 + if isinstance( + frame, + (LLMFullResponseStartFrame, LLMFullResponseEndFrame, TextFrame), + ): + await self.push_frame(frame, direction) + elif isinstance(frame, InterruptionFrame): + await self._call_event_handler("on_interruption_processed") async def run_pipeline(transport, cfg: AssistantConfig) -> None: @@ -103,9 +137,9 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None: stt, llm, tts = create_services(cfg) context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}]) - user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + user_aggregator = LLMUserAggregator( context, - user_params=LLMUserAggregatorParams( + params=LLMUserAggregatorParams( vad_analyzer=SileroVADAnalyzer(), user_turn_strategies=UserTurnStrategies( start=[ @@ -117,6 +151,7 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None: ), ), ) + assistant_aggregator = PassthroughLLMAssistantAggregator(context) text_input = TextInputProcessor() pipeline = Pipeline( @@ -126,9 +161,12 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None: stt, user_aggregator, llm, + # Aggregate the streamed LLM text before TTS. On interruption, + # Pipecat commits the generated prefix immediately instead of + # waiting for a TTS provider to emit spoken-text/timestamp frames. + assistant_aggregator, tts, transport.output(), - assistant_aggregator, ] ) @@ -153,22 +191,58 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None: ) ) + greeting_transcript_sent = False + pending_text_inputs: list[str] = [] + + async def append_user_text_to_context(text: str, *, run_llm: bool) -> None: + await worker.queue_frame( + LLMMessagesAppendFrame( + messages=[{"role": "user", "content": text}], + run_llm=run_llm, + ) + ) + @user_aggregator.event_handler("on_user_turn_stopped") async def on_user_turn_stopped(_aggregator, _strategy, message): await queue_transcript("user", message.content, message.timestamp) @assistant_aggregator.event_handler("on_assistant_turn_stopped") async def on_assistant_turn_stopped(_aggregator, message): + # 助手半句此刻已写入上下文,上报为 transcript await queue_transcript("assistant", message.content, message.timestamp) @text_input.event_handler("on_text_input") async def on_text_input(_processor, text): + pending_text_inputs.append(text) + # 前端显示不依赖 interruption 后续事件,必须在打断前先排入发送队列。 await queue_transcript("user", text, time_now_iso8601()) + @assistant_aggregator.event_handler("on_interruption_processed") + async def on_interruption_processed(_aggregator): + if not pending_text_inputs: + return + text = pending_text_inputs.pop(0) + # assistant aggregator 已处理完 interruption,现在再启动下一轮 LLM。 + await append_user_text_to_context(text, run_llm=True) + + @text_input.event_handler("on_text_append") + async def on_text_append(_processor, text): + # 静默追加:写进上下文但不打断、不触发推理;transcript 照常上报 + await queue_transcript("user", text, time_now_iso8601()) + await append_user_text_to_context(text, run_llm=False) + + @text_input.event_handler("on_client_ready") + async def on_client_ready(_processor): + nonlocal greeting_transcript_sent + if cfg.greeting and not greeting_transcript_sent: + greeting_transcript_sent = True + await queue_transcript("assistant", cfg.greeting, time_now_iso8601()) + @transport.event_handler("on_client_connected") async def on_client_connected(_transport, _client): if cfg.greeting: - await worker.queue_frame(TTSSpeakFrame(cfg.greeting)) + context.add_message({"role": "assistant", "content": cfg.greeting}) + await worker.queue_frame(TTSSpeakFrame(cfg.greeting, append_to_context=False)) @transport.event_handler("on_client_disconnected") async def on_client_disconnected(_transport, _client): diff --git a/frontend/src/hooks/use-voice-preview.ts b/frontend/src/hooks/use-voice-preview.ts index 650a8f3..f1631cb 100644 --- a/frontend/src/hooks/use-voice-preview.ts +++ b/frontend/src/hooks/use-voice-preview.ts @@ -54,6 +54,11 @@ function errorMessage(error: unknown, fallback: string): string { return fallback; } +function messageOrder(message: ChatMessage): number { + const timestamp = Date.parse(message.timestamp); + return Number.isNaN(timestamp) ? Number.MAX_SAFE_INTEGER : timestamp; +} + function microphoneErrorMessage(error: unknown): string { if (error instanceof DOMException) { if (error.name === "NotAllowedError") { @@ -137,6 +142,7 @@ export function useVoicePreview(assistantId: string | null) { const channel = dataChannelRef.current; dataChannelRef.current = null; if (channel) { + channel.onopen = null; channel.onmessage = null; channel.close(); } @@ -290,6 +296,9 @@ export function useVoicePreview(assistantId: string | null) { // 由浏览器侧主动创建,后端 SmallWebRTCConnection 的 on("datachannel") 会接住。 const channel = pc.createDataChannel("chat"); dataChannelRef.current = channel; + channel.onopen = () => { + channel.send(JSON.stringify({ type: "client-ready" })); + }; channel.onmessage = (event) => { try { const msg = JSON.parse(event.data); @@ -309,7 +318,14 @@ export function useVoicePreview(assistantId: string | null) { ? msg.timestamp : new Date().toISOString(), }; - setMessages((prev) => [...prev, next]); + setMessages((prev) => + [...prev, next].sort( + (a, b) => + messageOrder(a) - messageOrder(b) || + Number(a.id.replace("msg-", "")) - + Number(b.id.replace("msg-", "")), + ), + ); } } catch { /* 非 JSON / 未知消息,忽略 */