Enhance LLM text streaming and message handling in backend and frontend

- Introduce event handlers in PassthroughLLMAssistantAggregator for managing LLM text streaming, including start, delta, and end events.
- Implement a new method to finalize text streams, ensuring proper handling of interruptions.
- Update useVoicePreview to support new message types for LLM text streaming, allowing real-time updates to chat messages.
- Enhance message sorting logic to maintain order based on timestamps and sequence numbers, improving user experience during voice interactions.
This commit is contained in:
Xin Wang
2026-06-14 22:18:21 +08:00
parent b749d2e075
commit d55b87cfbf
2 changed files with 137 additions and 12 deletions

View File

@@ -6,6 +6,8 @@
对应 dograh 的 pipeline_builder.py + run_pipeline.py(已砍掉 workflow 引擎/DB/录音/指标)。
"""
from uuid import uuid4
from loguru import logger
from models import AssistantConfig
from services.pipecat.service_factory import create_services
@@ -17,6 +19,7 @@ from pipecat.frames.frames import (
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
LLMMessagesAppendFrame,
OutputTransportMessageUrgentFrame,
TextFrame,
@@ -109,10 +112,35 @@ class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_event_handler("on_interruption_processed")
self._register_event_handler("on_assistant_text_start")
self._register_event_handler("on_assistant_text_delta")
self._register_event_handler("on_assistant_text_end")
self._stream_turn_id: str | None = None
self._stream_timestamp = ""
self._stream_text = ""
async def process_frame(self, frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, LLMFullResponseStartFrame):
self._stream_turn_id = uuid4().hex
self._stream_timestamp = time_now_iso8601()
self._stream_text = ""
await self._call_event_handler(
"on_assistant_text_start",
self._stream_turn_id,
self._stream_timestamp,
)
elif isinstance(frame, LLMTextFrame) and self._stream_turn_id:
self._stream_text += frame.text
await self._call_event_handler(
"on_assistant_text_delta",
self._stream_turn_id,
frame.text,
)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._finish_text_stream(interrupted=False)
# LLMAssistantAggregator 默认会消费这些帧。放在 TTS 前用于中断时保存
# 已生成前缀时,必须显式透传,否则 TTS 收不到任何 LLM 回复。
if isinstance(
@@ -121,8 +149,22 @@ class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
):
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
await self._finish_text_stream(interrupted=True)
await self._call_event_handler("on_interruption_processed")
async def _finish_text_stream(self, *, interrupted: bool):
if not self._stream_turn_id:
return
await self._call_event_handler(
"on_assistant_text_end",
self._stream_turn_id,
self._stream_text,
interrupted,
)
self._stream_turn_id = None
self._stream_timestamp = ""
self._stream_text = ""
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
"""在给定 transport 上构建并运行管线,直到连接结束。
@@ -206,10 +248,42 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
async def on_user_turn_stopped(_aggregator, _strategy, message):
await queue_transcript("user", message.content, message.timestamp)
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message):
# 助手半句此刻已写入上下文,上报为 transcript
await queue_transcript("assistant", message.content, message.timestamp)
@assistant_aggregator.event_handler("on_assistant_text_start")
async def on_assistant_text_start(_aggregator, turn_id, timestamp):
await worker.queue_frame(
OutputTransportMessageUrgentFrame(
message={
"type": "assistant-text-start",
"turn_id": turn_id,
"timestamp": timestamp,
}
)
)
@assistant_aggregator.event_handler("on_assistant_text_delta")
async def on_assistant_text_delta(_aggregator, turn_id, delta):
await worker.queue_frame(
OutputTransportMessageUrgentFrame(
message={
"type": "assistant-text-delta",
"turn_id": turn_id,
"delta": delta,
}
)
)
@assistant_aggregator.event_handler("on_assistant_text_end")
async def on_assistant_text_end(_aggregator, turn_id, content, interrupted):
await worker.queue_frame(
OutputTransportMessageUrgentFrame(
message={
"type": "assistant-text-end",
"turn_id": turn_id,
"content": content,
"interrupted": interrupted,
}
)
)
@text_input.event_handler("on_text_input")
async def on_text_input(_processor, text):

View File

@@ -29,6 +29,9 @@ export type ChatMessage = {
content: string;
/** 后端给的 ISO 时间戳 */
timestamp: string;
sequence: number;
turnId?: string;
streaming?: boolean;
};
// http→ws、https→wss,自动跟随 API 基址(同源反代时也对)
@@ -59,6 +62,12 @@ function messageOrder(message: ChatMessage): number {
return Number.isNaN(timestamp) ? Number.MAX_SAFE_INTEGER : timestamp;
}
function sortMessages(messages: ChatMessage[]): ChatMessage[] {
return messages.sort(
(a, b) => messageOrder(a) - messageOrder(b) || a.sequence - b.sequence,
);
}
function microphoneErrorMessage(error: unknown): string {
if (error instanceof DOMException) {
if (error.name === "NotAllowedError") {
@@ -303,6 +312,54 @@ export function useVoicePreview(assistantId: string | null) {
try {
const msg = JSON.parse(event.data);
if (
msg?.type === "assistant-text-start" &&
typeof msg.turn_id === "string"
) {
messageSeqRef.current += 1;
const next: ChatMessage = {
id: `assistant-${msg.turn_id}`,
role: "assistant",
content: "",
timestamp:
typeof msg.timestamp === "string"
? msg.timestamp
: new Date().toISOString(),
sequence: messageSeqRef.current,
turnId: msg.turn_id,
streaming: true,
};
setMessages((prev) => sortMessages([...prev, next]));
} else if (
msg?.type === "assistant-text-delta" &&
typeof msg.turn_id === "string" &&
typeof msg.delta === "string"
) {
setMessages((prev) =>
prev.map((message) =>
message.turnId === msg.turn_id
? { ...message, content: message.content + msg.delta }
: message,
),
);
} else if (
msg?.type === "assistant-text-end" &&
typeof msg.turn_id === "string"
) {
setMessages((prev) =>
prev.map((message) =>
message.turnId === msg.turn_id
? {
...message,
content:
typeof msg.content === "string"
? msg.content
: message.content,
streaming: false,
}
: message,
),
);
} else if (
msg?.type === "transcript" &&
(msg.role === "user" || msg.role === "assistant") &&
typeof msg.content === "string" &&
@@ -317,15 +374,9 @@ export function useVoicePreview(assistantId: string | null) {
typeof msg.timestamp === "string"
? msg.timestamp
: new Date().toISOString(),
sequence: messageSeqRef.current,
};
setMessages((prev) =>
[...prev, next].sort(
(a, b) =>
messageOrder(a) - messageOrder(b) ||
Number(a.id.replace("msg-", "")) -
Number(b.id.replace("msg-", "")),
),
);
setMessages((prev) => sortMessages([...prev, next]));
}
} catch {
/* 非 JSON / 未知消息,忽略 */