Enhance text input processing and LLM interaction in the backend
- Refactor TextInputProcessor to handle immediate and silent text inputs, improving user experience during voice interactions.
- Introduce PassthroughLLMAssistantAggregator to manage LLM responses while preserving context for downstream TTS processing.
- Update event handling for text input and client readiness, ensuring timely updates to the conversation context.
- Modify run_pipeline to integrate the new aggregators and streamline message handling, enhancing overall pipeline efficiency.
- Improve message ordering in useVoicePreview to ensure chat messages are displayed accurately based on their timestamps.
This commit is contained in:
@@ -13,17 +13,21 @@ from services.pipecat.service_factory import create_services
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
InputTextRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMAssistantAggregator,
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -57,11 +61,21 @@ def _text_input(message) -> tuple[str, bool] | None:
|
||||
|
||||
|
||||
class TextInputProcessor(FrameProcessor):
|
||||
"""把 transport 文字消息转换成级联与实时 LLM 都能消费的帧。"""
|
||||
"""把 transport 文字消息转换成 LLM 可消费的帧。
|
||||
|
||||
run_immediately(默认/打断):先通过 on_text_input 事件把用户文字交给
|
||||
run_pipeline 登记,再用 broadcast_interruption() 打断当前播报。新的 LLM
|
||||
回复由 assistant aggregator 确认处理完 interruption 后触发。
|
||||
run_immediately=False(RTVI send-text 静默追加):仅把文字写进上下文,
|
||||
不打断、不触发推理。
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Initialise the processor and register the events it can emit."""
    super().__init__()
    # Text that triggers immediately (with interruption semantics) goes
    # through on_text_input; silent appends use a separate event.
    for event_name in ("on_text_input", "on_text_append", "on_client_ready"):
        self._register_event_handler(event_name)
|
||||
|
||||
async def process_frame(self, frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -70,6 +84,10 @@ class TextInputProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
if isinstance(frame.message, dict) and frame.message.get("type") == "client-ready":
|
||||
await self._call_event_handler("on_client_ready")
|
||||
return
|
||||
|
||||
parsed = _text_input(frame.message)
|
||||
if not parsed:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -77,17 +95,33 @@ class TextInputProcessor(FrameProcessor):
|
||||
|
||||
text, run_immediately = parsed
|
||||
if run_immediately:
|
||||
# 先登记文字再打断。下一轮 LLM 由 assistant aggregator 在真正处理完
|
||||
# InterruptionFrame 后触发,避免新回复被这次 interruption 一起取消。
|
||||
await self._call_event_handler("on_text_input", text)
|
||||
await self.broadcast_interruption()
|
||||
else:
|
||||
await self._call_event_handler("on_text_append", text)
|
||||
|
||||
await self.push_frame(
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
run_llm=run_immediately,
|
||||
)
|
||||
)
|
||||
if run_immediately:
|
||||
await self.push_frame(InputTextRawFrame(text=text))
|
||||
await self._call_event_handler("on_text_input", text)
|
||||
|
||||
class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
    """Aggregate LLM replies into the context while still handing the reply
    frames on to the downstream TTS.

    Also emits an ``on_interruption_processed`` event after the base class
    has processed an ``InterruptionFrame``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base aggregator and register the extra event."""
        super().__init__(*args, **kwargs)
        self._register_event_handler("on_interruption_processed")

    async def process_frame(self, frame, direction: FrameDirection):
        """Run the base aggregation, then re-push reply frames or signal an interruption."""
        await super().process_frame(frame, direction)

        # LLMAssistantAggregator consumes these frames by default. When it is
        # placed before TTS so the already-generated prefix can be saved on
        # interruption, they must be forwarded explicitly; otherwise TTS
        # never receives any LLM reply.
        reply_frame_types = (
            LLMFullResponseStartFrame,
            LLMFullResponseEndFrame,
            TextFrame,
        )
        if isinstance(frame, reply_frame_types):
            await self.push_frame(frame, direction)
            return

        if isinstance(frame, InterruptionFrame):
            await self._call_event_handler("on_interruption_processed")
|
||||
|
||||
|
||||
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
@@ -103,9 +137,9 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
stt, llm, tts = create_services(cfg)
|
||||
|
||||
context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
start=[
|
||||
@@ -117,6 +151,7 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
),
|
||||
),
|
||||
)
|
||||
assistant_aggregator = PassthroughLLMAssistantAggregator(context)
|
||||
text_input = TextInputProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -126,9 +161,12 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
# Aggregate the streamed LLM text before TTS. On interruption,
|
||||
# Pipecat commits the generated prefix immediately instead of
|
||||
# waiting for a TTS provider to emit spoken-text/timestamp frames.
|
||||
assistant_aggregator,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -153,22 +191,58 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
)
|
||||
)
|
||||
|
||||
greeting_transcript_sent = False
|
||||
pending_text_inputs: list[str] = []
|
||||
|
||||
async def append_user_text_to_context(text: str, *, run_llm: bool) -> None:
    """Queue an LLMMessagesAppendFrame carrying ``text`` as a user message.

    ``run_llm`` is passed through unchanged to the frame.
    """
    user_message = {"role": "user", "content": text}
    append_frame = LLMMessagesAppendFrame(
        messages=[user_message],
        run_llm=run_llm,
    )
    await worker.queue_frame(append_frame)
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
async def on_user_turn_stopped(_aggregator, _strategy, message):
    """Hand the finished user turn to queue_transcript under the "user" role."""
    content = message.content
    stamp = message.timestamp
    await queue_transcript("user", content, stamp)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message):
    """Hand the finished assistant turn to queue_transcript under the "assistant" role."""
    # The assistant's (possibly partial) sentence has been written to the
    # context by now; report it as a transcript.
    content = message.content
    stamp = message.timestamp
    await queue_transcript("assistant", content, stamp)
|
||||
|
||||
@text_input.event_handler("on_text_input")
async def on_text_input(_processor, text):
    """Record an immediate text input as pending and enqueue its transcript."""
    # Remember the text so the interruption-processed handler can start the
    # next LLM round with it.
    pending_text_inputs.append(text)

    # Frontend display does not depend on events that follow the
    # interruption, so the transcript must be enqueued before the
    # interruption happens.
    stamp = time_now_iso8601()
    await queue_transcript("user", text, stamp)
|
||||
|
||||
@assistant_aggregator.event_handler("on_interruption_processed")
async def on_interruption_processed(_aggregator):
    """Start the next LLM round for the oldest pending text input, if any."""
    if pending_text_inputs:
        oldest_text = pending_text_inputs.pop(0)
        # The assistant aggregator has finished handling the interruption,
        # so the next LLM round can be started now.
        await append_user_text_to_context(oldest_text, run_llm=True)
|
||||
|
||||
@text_input.event_handler("on_text_append")
async def on_text_append(_processor, text):
    """Handle a silent text append: enqueue the transcript, then add the text to the context."""
    # Silent append: written into the context without interrupting or
    # triggering inference; the transcript is still reported as usual.
    stamp = time_now_iso8601()
    await queue_transcript("user", text, stamp)
    await append_user_text_to_context(text, run_llm=False)
|
||||
|
||||
@text_input.event_handler("on_client_ready")
|
||||
async def on_client_ready(_processor):
|
||||
nonlocal greeting_transcript_sent
|
||||
if cfg.greeting and not greeting_transcript_sent:
|
||||
greeting_transcript_sent = True
|
||||
await queue_transcript("assistant", cfg.greeting, time_now_iso8601())
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _client):
|
||||
if cfg.greeting:
|
||||
await worker.queue_frame(TTSSpeakFrame(cfg.greeting))
|
||||
context.add_message({"role": "assistant", "content": cfg.greeting})
|
||||
await worker.queue_frame(TTSSpeakFrame(cfg.greeting, append_to_context=False))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(_transport, _client):
|
||||
|
||||
@@ -54,6 +54,11 @@ function errorMessage(error: unknown, fallback: string): string {
|
||||
return fallback;
|
||||
}
|
||||
|
||||
/**
 * Sort key for a chat message: the parsed timestamp in milliseconds, or
 * Number.MAX_SAFE_INTEGER when the timestamp cannot be parsed.
 */
function messageOrder(message: ChatMessage): number {
  const parsedMs = Date.parse(message.timestamp);
  if (Number.isNaN(parsedMs)) {
    return Number.MAX_SAFE_INTEGER;
  }
  return parsedMs;
}
|
||||
|
||||
function microphoneErrorMessage(error: unknown): string {
|
||||
if (error instanceof DOMException) {
|
||||
if (error.name === "NotAllowedError") {
|
||||
@@ -137,6 +142,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
const channel = dataChannelRef.current;
|
||||
dataChannelRef.current = null;
|
||||
if (channel) {
|
||||
channel.onopen = null;
|
||||
channel.onmessage = null;
|
||||
channel.close();
|
||||
}
|
||||
@@ -290,6 +296,9 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
// 由浏览器侧主动创建,后端 SmallWebRTCConnection 的 on("datachannel") 会接住。
|
||||
const channel = pc.createDataChannel("chat");
|
||||
dataChannelRef.current = channel;
|
||||
channel.onopen = () => {
|
||||
channel.send(JSON.stringify({ type: "client-ready" }));
|
||||
};
|
||||
channel.onmessage = (event) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data);
|
||||
@@ -309,7 +318,14 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
? msg.timestamp
|
||||
: new Date().toISOString(),
|
||||
};
|
||||
setMessages((prev) => [...prev, next]);
|
||||
setMessages((prev) =>
|
||||
[...prev, next].sort(
|
||||
(a, b) =>
|
||||
messageOrder(a) - messageOrder(b) ||
|
||||
Number(a.id.replace("msg-", "")) -
|
||||
Number(b.id.replace("msg-", "")),
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
/* 非 JSON / 未知消息,忽略 */
|
||||
|
||||
Reference in New Issue
Block a user