Enhance text input processing and LLM interaction in the backend

- Refactor TextInputProcessor to handle immediate and silent text inputs, improving user experience during voice interactions.
- Introduce PassthroughLLMAssistantAggregator to manage LLM responses while preserving context for downstream TTS processing.
- Update event handling for text input and client readiness, ensuring timely updates to the conversation context.
- Modify run_pipeline to integrate new aggregators and streamline message handling, enhancing overall pipeline efficiency.
- Improve message ordering in useVoicePreview to ensure accurate display of chat messages based on timestamps.
This commit is contained in:
Xin Wang
2026-06-14 22:12:56 +08:00
parent 86d9acce78
commit b749d2e075
2 changed files with 107 additions and 17 deletions

View File

@@ -13,17 +13,21 @@ from services.pipecat.service_factory import create_services
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
EndFrame,
InputTextRawFrame,
InputTransportMessageFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
OutputTransportMessageUrgentFrame,
TextFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMAssistantAggregator,
LLMUserAggregator,
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -57,11 +61,21 @@ def _text_input(message) -> tuple[str, bool] | None:
class TextInputProcessor(FrameProcessor):
"""把 transport 文字消息转换成级联与实时 LLM 都能消费的帧。"""
"""把 transport 文字消息转换成 LLM 消费的帧。
run_immediately(默认/打断):先通过 on_text_input 事件把用户文字交给
run_pipeline 登记,再用 broadcast_interruption() 打断当前播报。新的 LLM
回复由 assistant aggregator 确认处理完 interruption 后触发。
run_immediately=False(RTVI send-text 静默追加):仅把文字写进上下文,
不打断、不触发推理。
"""
def __init__(self):
super().__init__()
# 立即触发的文字(含打断语义)走 on_text_input;静默追加另走一条事件
self._register_event_handler("on_text_input")
self._register_event_handler("on_text_append")
self._register_event_handler("on_client_ready")
async def process_frame(self, frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -70,6 +84,10 @@ class TextInputProcessor(FrameProcessor):
await self.push_frame(frame, direction)
return
if isinstance(frame.message, dict) and frame.message.get("type") == "client-ready":
await self._call_event_handler("on_client_ready")
return
parsed = _text_input(frame.message)
if not parsed:
await self.push_frame(frame, direction)
@@ -77,17 +95,33 @@ class TextInputProcessor(FrameProcessor):
text, run_immediately = parsed
if run_immediately:
# 先登记文字再打断。下一轮 LLM 由 assistant aggregator 在真正处理完
# InterruptionFrame 后触发,避免新回复被这次 interruption 一起取消。
await self._call_event_handler("on_text_input", text)
await self.broadcast_interruption()
else:
await self._call_event_handler("on_text_append", text)
await self.push_frame(
LLMMessagesAppendFrame(
messages=[{"role": "user", "content": text}],
run_llm=run_immediately,
)
)
if run_immediately:
await self.push_frame(InputTextRawFrame(text=text))
await self._call_event_handler("on_text_input", text)
class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
"""聚合 LLM 回复进上下文,同时继续把回复帧交给下游 TTS。"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_event_handler("on_interruption_processed")
async def process_frame(self, frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# LLMAssistantAggregator 默认会消费这些帧。放在 TTS 前用于中断时保存
# 已生成前缀时,必须显式透传,否则 TTS 收不到任何 LLM 回复。
if isinstance(
frame,
(LLMFullResponseStartFrame, LLMFullResponseEndFrame, TextFrame),
):
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
await self._call_event_handler("on_interruption_processed")
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
@@ -103,9 +137,9 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
stt, llm, tts = create_services(cfg)
context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}])
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
user_aggregator = LLMUserAggregator(
context,
user_params=LLMUserAggregatorParams(
params=LLMUserAggregatorParams(
vad_analyzer=SileroVADAnalyzer(),
user_turn_strategies=UserTurnStrategies(
start=[
@@ -117,6 +151,7 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
),
),
)
assistant_aggregator = PassthroughLLMAssistantAggregator(context)
text_input = TextInputProcessor()
pipeline = Pipeline(
@@ -126,9 +161,12 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
stt,
user_aggregator,
llm,
# Aggregate the streamed LLM text before TTS. On interruption,
# Pipecat commits the generated prefix immediately instead of
# waiting for a TTS provider to emit spoken-text/timestamp frames.
assistant_aggregator,
tts,
transport.output(),
assistant_aggregator,
]
)
@@ -153,22 +191,58 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
)
)
greeting_transcript_sent = False
pending_text_inputs: list[str] = []
async def append_user_text_to_context(text: str, *, run_llm: bool) -> None:
await worker.queue_frame(
LLMMessagesAppendFrame(
messages=[{"role": "user", "content": text}],
run_llm=run_llm,
)
)
@user_aggregator.event_handler("on_user_turn_stopped")
async def on_user_turn_stopped(_aggregator, _strategy, message):
await queue_transcript("user", message.content, message.timestamp)
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message):
# 助手半句此刻已写入上下文,上报为 transcript
await queue_transcript("assistant", message.content, message.timestamp)
@text_input.event_handler("on_text_input")
async def on_text_input(_processor, text):
pending_text_inputs.append(text)
# 前端显示不依赖 interruption 后续事件,必须在打断前先排入发送队列。
await queue_transcript("user", text, time_now_iso8601())
@assistant_aggregator.event_handler("on_interruption_processed")
async def on_interruption_processed(_aggregator):
if not pending_text_inputs:
return
text = pending_text_inputs.pop(0)
# assistant aggregator 已处理完 interruption,现在再启动下一轮 LLM。
await append_user_text_to_context(text, run_llm=True)
@text_input.event_handler("on_text_append")
async def on_text_append(_processor, text):
# 静默追加:写进上下文但不打断、不触发推理;transcript 照常上报
await queue_transcript("user", text, time_now_iso8601())
await append_user_text_to_context(text, run_llm=False)
@text_input.event_handler("on_client_ready")
async def on_client_ready(_processor):
nonlocal greeting_transcript_sent
if cfg.greeting and not greeting_transcript_sent:
greeting_transcript_sent = True
await queue_transcript("assistant", cfg.greeting, time_now_iso8601())
@transport.event_handler("on_client_connected")
async def on_client_connected(_transport, _client):
if cfg.greeting:
await worker.queue_frame(TTSSpeakFrame(cfg.greeting))
context.add_message({"role": "assistant", "content": cfg.greeting})
await worker.queue_frame(TTSSpeakFrame(cfg.greeting, append_to_context=False))
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(_transport, _client):

View File

@@ -54,6 +54,11 @@ function errorMessage(error: unknown, fallback: string): string {
return fallback;
}
function messageOrder(message: ChatMessage): number {
const timestamp = Date.parse(message.timestamp);
return Number.isNaN(timestamp) ? Number.MAX_SAFE_INTEGER : timestamp;
}
function microphoneErrorMessage(error: unknown): string {
if (error instanceof DOMException) {
if (error.name === "NotAllowedError") {
@@ -137,6 +142,7 @@ export function useVoicePreview(assistantId: string | null) {
const channel = dataChannelRef.current;
dataChannelRef.current = null;
if (channel) {
channel.onopen = null;
channel.onmessage = null;
channel.close();
}
@@ -290,6 +296,9 @@ export function useVoicePreview(assistantId: string | null) {
// 由浏览器侧主动创建,后端 SmallWebRTCConnection 的 on("datachannel") 会接住。
const channel = pc.createDataChannel("chat");
dataChannelRef.current = channel;
channel.onopen = () => {
channel.send(JSON.stringify({ type: "client-ready" }));
};
channel.onmessage = (event) => {
try {
const msg = JSON.parse(event.data);
@@ -309,7 +318,14 @@ export function useVoicePreview(assistantId: string | null) {
? msg.timestamp
: new Date().toISOString(),
};
setMessages((prev) => [...prev, next]);
setMessages((prev) =>
[...prev, next].sort(
(a, b) =>
messageOrder(a) - messageOrder(b) ||
Number(a.id.replace("msg-", "")) -
Number(b.id.replace("msg-", "")),
),
);
}
} catch {
/* 非 JSON / 未知消息,忽略 */