Enhance LLM text streaming and message handling in backend and frontend
- Introduce event handlers in PassthroughLLMAssistantAggregator for managing LLM text streaming, including start, delta, and end events.
- Implement a new method to finalize text streams, ensuring proper handling of interruptions.
- Update useVoicePreview to support new message types for LLM text streaming, allowing real-time updates to chat messages.
- Enhance message sorting logic to maintain order based on timestamps and sequence numbers, improving user experience during voice interactions.
This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
对应 dograh 的 pipeline_builder.py + run_pipeline.py(已砍掉 workflow 引擎/DB/录音/指标)。
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
from models import AssistantConfig
|
||||
from services.pipecat.service_factory import create_services
|
||||
@@ -17,6 +19,7 @@ from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
TextFrame,
|
||||
@@ -109,10 +112,35 @@ class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
|
||||
def __init__(self, *args, **kwargs):
    """Initialize the parent aggregator, then the streaming events and state.

    All positional and keyword arguments are forwarded unchanged to the
    parent initializer.
    """
    super().__init__(*args, **kwargs)

    # Event names this aggregator can emit to registered handlers.
    for event_name in (
        "on_interruption_processed",
        "on_assistant_text_start",
        "on_assistant_text_delta",
        "on_assistant_text_end",
    ):
        self._register_event_handler(event_name)

    # Per-turn streaming state. A turn id of None means no text stream is
    # currently active.
    self._stream_turn_id: str | None = None
    self._stream_timestamp = ""
    self._stream_text = ""
|
||||
|
||||
async def process_frame(self, frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self._stream_turn_id = uuid4().hex
|
||||
self._stream_timestamp = time_now_iso8601()
|
||||
self._stream_text = ""
|
||||
await self._call_event_handler(
|
||||
"on_assistant_text_start",
|
||||
self._stream_turn_id,
|
||||
self._stream_timestamp,
|
||||
)
|
||||
elif isinstance(frame, LLMTextFrame) and self._stream_turn_id:
|
||||
self._stream_text += frame.text
|
||||
await self._call_event_handler(
|
||||
"on_assistant_text_delta",
|
||||
self._stream_turn_id,
|
||||
frame.text,
|
||||
)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._finish_text_stream(interrupted=False)
|
||||
|
||||
# LLMAssistantAggregator 默认会消费这些帧。放在 TTS 前用于中断时保存
|
||||
# 已生成前缀时,必须显式透传,否则 TTS 收不到任何 LLM 回复。
|
||||
if isinstance(
|
||||
@@ -121,8 +149,22 @@ class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._finish_text_stream(interrupted=True)
|
||||
await self._call_event_handler("on_interruption_processed")
|
||||
|
||||
async def _finish_text_stream(self, *, interrupted: bool):
    """Emit ``on_assistant_text_end`` for the active text stream and reset it.

    Does nothing when no stream is active (``_stream_turn_id`` is unset).
    The handler receives the turn id, the text accumulated so far, and
    ``interrupted``.

    Args:
        interrupted: Forwarded to the event handler as its last argument.
            Callers pass ``True`` when the stream ends because of an
            interruption and ``False`` on a normal end of response.
    """
    turn_id = self._stream_turn_id
    if not turn_id:
        return
    text = self._stream_text

    # Reset the state *before* awaiting the handler. The await can yield to
    # the event loop; if another call arrives meanwhile it must see no
    # active stream, otherwise the end event fires twice for one turn. A
    # late reset could also wipe the state of a stream started during the
    # await.
    # NOTE(review): whether end-of-response and interruption handling can
    # actually interleave depends on the frame processor's task model --
    # confirm; the early reset is harmless either way.
    self._stream_turn_id = None
    self._stream_timestamp = ""
    self._stream_text = ""

    await self._call_event_handler(
        "on_assistant_text_end",
        turn_id,
        text,
        interrupted,
    )
|
||||
|
||||
|
||||
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
"""在给定 transport 上构建并运行管线,直到连接结束。
|
||||
@@ -206,10 +248,42 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
async def on_user_turn_stopped(_aggregator, _strategy, message):
|
||||
await queue_transcript("user", message.content, message.timestamp)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message):
    """Queue the stopped assistant turn as an "assistant" transcript entry."""
    # By this point the assistant's partial sentence has already been written
    # into the context, so report it as a transcript.
    content, timestamp = message.content, message.timestamp
    await queue_transcript("assistant", content, timestamp)
|
||||
@assistant_aggregator.event_handler("on_assistant_text_start")
async def on_assistant_text_start(_aggregator, turn_id, timestamp):
    """Queue an ``assistant-text-start`` transport message for a new turn.

    The message carries the turn id and the timestamp received from the
    aggregator event.
    """
    payload = {
        "type": "assistant-text-start",
        "turn_id": turn_id,
        "timestamp": timestamp,
    }
    start_frame = OutputTransportMessageUrgentFrame(message=payload)
    await worker.queue_frame(start_frame)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_text_delta")
async def on_assistant_text_delta(_aggregator, turn_id, delta):
    """Queue an ``assistant-text-delta`` transport message.

    The message carries the turn id and the newly received text fragment.
    """
    payload = {
        "type": "assistant-text-delta",
        "turn_id": turn_id,
        "delta": delta,
    }
    delta_frame = OutputTransportMessageUrgentFrame(message=payload)
    await worker.queue_frame(delta_frame)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_text_end")
async def on_assistant_text_end(_aggregator, turn_id, content, interrupted):
    """Queue an ``assistant-text-end`` transport message for a finished turn.

    The message carries the turn id, the text passed in as ``content``, and
    the ``interrupted`` flag.
    """
    payload = {
        "type": "assistant-text-end",
        "turn_id": turn_id,
        "content": content,
        "interrupted": interrupted,
    }
    end_frame = OutputTransportMessageUrgentFrame(message=payload)
    await worker.queue_frame(end_frame)
|
||||
|
||||
@text_input.event_handler("on_text_input")
|
||||
async def on_text_input(_processor, text):
|
||||
|
||||
@@ -29,6 +29,9 @@ export type ChatMessage = {
|
||||
content: string;
|
||||
/** 后端给的 ISO 时间戳 */
|
||||
timestamp: string;
|
||||
sequence: number;
|
||||
turnId?: string;
|
||||
streaming?: boolean;
|
||||
};
|
||||
|
||||
// http→ws、https→wss,自动跟随 API 基址(同源反代时也对)
|
||||
@@ -59,6 +62,12 @@ function messageOrder(message: ChatMessage): number {
|
||||
return Number.isNaN(timestamp) ? Number.MAX_SAFE_INTEGER : timestamp;
|
||||
}
|
||||
|
||||
/**
 * Return the messages ordered by `messageOrder`, with `sequence` as the
 * tie-breaker for messages that compare equal.
 *
 * `Array.prototype.sort` sorts in place, so a copy is sorted and the input
 * array is left untouched. This keeps the helper safe to call directly on a
 * state array inside a `setMessages` updater.
 */
function sortMessages(messages: ChatMessage[]): ChatMessage[] {
  return [...messages].sort(
    (a, b) => messageOrder(a) - messageOrder(b) || a.sequence - b.sequence,
  );
}
|
||||
|
||||
function microphoneErrorMessage(error: unknown): string {
|
||||
if (error instanceof DOMException) {
|
||||
if (error.name === "NotAllowedError") {
|
||||
@@ -303,6 +312,54 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
try {
|
||||
const msg = JSON.parse(event.data);
|
||||
if (
|
||||
msg?.type === "assistant-text-start" &&
|
||||
typeof msg.turn_id === "string"
|
||||
) {
|
||||
messageSeqRef.current += 1;
|
||||
const next: ChatMessage = {
|
||||
id: `assistant-${msg.turn_id}`,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
timestamp:
|
||||
typeof msg.timestamp === "string"
|
||||
? msg.timestamp
|
||||
: new Date().toISOString(),
|
||||
sequence: messageSeqRef.current,
|
||||
turnId: msg.turn_id,
|
||||
streaming: true,
|
||||
};
|
||||
setMessages((prev) => sortMessages([...prev, next]));
|
||||
} else if (
|
||||
msg?.type === "assistant-text-delta" &&
|
||||
typeof msg.turn_id === "string" &&
|
||||
typeof msg.delta === "string"
|
||||
) {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.turnId === msg.turn_id
|
||||
? { ...message, content: message.content + msg.delta }
|
||||
: message,
|
||||
),
|
||||
);
|
||||
} else if (
|
||||
msg?.type === "assistant-text-end" &&
|
||||
typeof msg.turn_id === "string"
|
||||
) {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.turnId === msg.turn_id
|
||||
? {
|
||||
...message,
|
||||
content:
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: message.content,
|
||||
streaming: false,
|
||||
}
|
||||
: message,
|
||||
),
|
||||
);
|
||||
} else if (
|
||||
msg?.type === "transcript" &&
|
||||
(msg.role === "user" || msg.role === "assistant") &&
|
||||
typeof msg.content === "string" &&
|
||||
@@ -317,15 +374,9 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
typeof msg.timestamp === "string"
|
||||
? msg.timestamp
|
||||
: new Date().toISOString(),
|
||||
sequence: messageSeqRef.current,
|
||||
};
|
||||
setMessages((prev) =>
|
||||
[...prev, next].sort(
|
||||
(a, b) =>
|
||||
messageOrder(a) - messageOrder(b) ||
|
||||
Number(a.id.replace("msg-", "")) -
|
||||
Number(b.id.replace("msg-", "")),
|
||||
),
|
||||
);
|
||||
setMessages((prev) => sortMessages([...prev, next]));
|
||||
}
|
||||
} catch {
|
||||
/* 非 JSON / 未知消息,忽略 */
|
||||
|
||||
Reference in New Issue
Block a user