Compare commits
3 Commits
90e3e8a0c0
...
d55b87cfbf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d55b87cfbf | ||
|
|
b749d2e075 | ||
|
|
86d9acce78 |
@@ -6,6 +6,8 @@
|
||||
对应 dograh 的 pipeline_builder.py + run_pipeline.py(已砍掉 workflow 引擎/DB/录音/指标)。
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
from models import AssistantConfig
|
||||
from services.pipecat.service_factory import create_services
|
||||
@@ -13,17 +15,22 @@ from services.pipecat.service_factory import create_services
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
InputTextRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMAssistantAggregator,
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -57,11 +64,21 @@ def _text_input(message) -> tuple[str, bool] | None:
|
||||
|
||||
|
||||
class TextInputProcessor(FrameProcessor):
|
||||
"""把 transport 文字消息转换成级联与实时 LLM 都能消费的帧。"""
|
||||
"""把 transport 文字消息转换成 LLM 可消费的帧。
|
||||
|
||||
run_immediately(默认/打断):先通过 on_text_input 事件把用户文字交给
|
||||
run_pipeline 登记,再用 broadcast_interruption() 打断当前播报。新的 LLM
|
||||
回复由 assistant aggregator 确认处理完 interruption 后触发。
|
||||
run_immediately=False(RTVI send-text 静默追加):仅把文字写进上下文,
|
||||
不打断、不触发推理。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 立即触发的文字(含打断语义)走 on_text_input;静默追加另走一条事件
|
||||
self._register_event_handler("on_text_input")
|
||||
self._register_event_handler("on_text_append")
|
||||
self._register_event_handler("on_client_ready")
|
||||
|
||||
async def process_frame(self, frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -70,6 +87,10 @@ class TextInputProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
if isinstance(frame.message, dict) and frame.message.get("type") == "client-ready":
|
||||
await self._call_event_handler("on_client_ready")
|
||||
return
|
||||
|
||||
parsed = _text_input(frame.message)
|
||||
if not parsed:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -77,17 +98,72 @@ class TextInputProcessor(FrameProcessor):
|
||||
|
||||
text, run_immediately = parsed
|
||||
if run_immediately:
|
||||
# 先登记文字再打断。下一轮 LLM 由 assistant aggregator 在真正处理完
|
||||
# InterruptionFrame 后触发,避免新回复被这次 interruption 一起取消。
|
||||
await self._call_event_handler("on_text_input", text)
|
||||
await self.broadcast_interruption()
|
||||
else:
|
||||
await self._call_event_handler("on_text_append", text)
|
||||
|
||||
await self.push_frame(
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
run_llm=run_immediately,
|
||||
|
||||
class PassthroughLLMAssistantAggregator(LLMAssistantAggregator):
|
||||
"""聚合 LLM 回复进上下文,同时继续把回复帧交给下游 TTS。"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._register_event_handler("on_interruption_processed")
|
||||
self._register_event_handler("on_assistant_text_start")
|
||||
self._register_event_handler("on_assistant_text_delta")
|
||||
self._register_event_handler("on_assistant_text_end")
|
||||
self._stream_turn_id: str | None = None
|
||||
self._stream_timestamp = ""
|
||||
self._stream_text = ""
|
||||
|
||||
async def process_frame(self, frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self._stream_turn_id = uuid4().hex
|
||||
self._stream_timestamp = time_now_iso8601()
|
||||
self._stream_text = ""
|
||||
await self._call_event_handler(
|
||||
"on_assistant_text_start",
|
||||
self._stream_turn_id,
|
||||
self._stream_timestamp,
|
||||
)
|
||||
elif isinstance(frame, LLMTextFrame) and self._stream_turn_id:
|
||||
self._stream_text += frame.text
|
||||
await self._call_event_handler(
|
||||
"on_assistant_text_delta",
|
||||
self._stream_turn_id,
|
||||
frame.text,
|
||||
)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._finish_text_stream(interrupted=False)
|
||||
|
||||
# LLMAssistantAggregator 默认会消费这些帧。放在 TTS 前用于中断时保存
|
||||
# 已生成前缀时,必须显式透传,否则 TTS 收不到任何 LLM 回复。
|
||||
if isinstance(
|
||||
frame,
|
||||
(LLMFullResponseStartFrame, LLMFullResponseEndFrame, TextFrame),
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._finish_text_stream(interrupted=True)
|
||||
await self._call_event_handler("on_interruption_processed")
|
||||
|
||||
async def _finish_text_stream(self, *, interrupted: bool):
|
||||
if not self._stream_turn_id:
|
||||
return
|
||||
await self._call_event_handler(
|
||||
"on_assistant_text_end",
|
||||
self._stream_turn_id,
|
||||
self._stream_text,
|
||||
interrupted,
|
||||
)
|
||||
if run_immediately:
|
||||
await self.push_frame(InputTextRawFrame(text=text))
|
||||
await self._call_event_handler("on_text_input", text)
|
||||
self._stream_turn_id = None
|
||||
self._stream_timestamp = ""
|
||||
self._stream_text = ""
|
||||
|
||||
|
||||
async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
@@ -103,9 +179,9 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
stt, llm, tts = create_services(cfg)
|
||||
|
||||
context = LLMContext(messages=[{"role": "system", "content": cfg.prompt}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
start=[
|
||||
@@ -117,6 +193,7 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
),
|
||||
),
|
||||
)
|
||||
assistant_aggregator = PassthroughLLMAssistantAggregator(context)
|
||||
text_input = TextInputProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -126,9 +203,12 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
# Aggregate the streamed LLM text before TTS. On interruption,
|
||||
# Pipecat commits the generated prefix immediately instead of
|
||||
# waiting for a TTS provider to emit spoken-text/timestamp frames.
|
||||
assistant_aggregator,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -153,22 +233,90 @@ async def run_pipeline(transport, cfg: AssistantConfig) -> None:
|
||||
)
|
||||
)
|
||||
|
||||
greeting_transcript_sent = False
|
||||
pending_text_inputs: list[str] = []
|
||||
|
||||
async def append_user_text_to_context(text: str, *, run_llm: bool) -> None:
|
||||
await worker.queue_frame(
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
run_llm=run_llm,
|
||||
)
|
||||
)
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(_aggregator, _strategy, message):
|
||||
await queue_transcript("user", message.content, message.timestamp)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
|
||||
async def on_assistant_turn_stopped(_aggregator, message):
|
||||
await queue_transcript("assistant", message.content, message.timestamp)
|
||||
@assistant_aggregator.event_handler("on_assistant_text_start")
|
||||
async def on_assistant_text_start(_aggregator, turn_id, timestamp):
|
||||
await worker.queue_frame(
|
||||
OutputTransportMessageUrgentFrame(
|
||||
message={
|
||||
"type": "assistant-text-start",
|
||||
"turn_id": turn_id,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_text_delta")
|
||||
async def on_assistant_text_delta(_aggregator, turn_id, delta):
|
||||
await worker.queue_frame(
|
||||
OutputTransportMessageUrgentFrame(
|
||||
message={
|
||||
"type": "assistant-text-delta",
|
||||
"turn_id": turn_id,
|
||||
"delta": delta,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_text_end")
|
||||
async def on_assistant_text_end(_aggregator, turn_id, content, interrupted):
|
||||
await worker.queue_frame(
|
||||
OutputTransportMessageUrgentFrame(
|
||||
message={
|
||||
"type": "assistant-text-end",
|
||||
"turn_id": turn_id,
|
||||
"content": content,
|
||||
"interrupted": interrupted,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@text_input.event_handler("on_text_input")
|
||||
async def on_text_input(_processor, text):
|
||||
pending_text_inputs.append(text)
|
||||
# 前端显示不依赖 interruption 后续事件,必须在打断前先排入发送队列。
|
||||
await queue_transcript("user", text, time_now_iso8601())
|
||||
|
||||
@assistant_aggregator.event_handler("on_interruption_processed")
|
||||
async def on_interruption_processed(_aggregator):
|
||||
if not pending_text_inputs:
|
||||
return
|
||||
text = pending_text_inputs.pop(0)
|
||||
# assistant aggregator 已处理完 interruption,现在再启动下一轮 LLM。
|
||||
await append_user_text_to_context(text, run_llm=True)
|
||||
|
||||
@text_input.event_handler("on_text_append")
|
||||
async def on_text_append(_processor, text):
|
||||
# 静默追加:写进上下文但不打断、不触发推理;transcript 照常上报
|
||||
await queue_transcript("user", text, time_now_iso8601())
|
||||
await append_user_text_to_context(text, run_llm=False)
|
||||
|
||||
@text_input.event_handler("on_client_ready")
|
||||
async def on_client_ready(_processor):
|
||||
nonlocal greeting_transcript_sent
|
||||
if cfg.greeting and not greeting_transcript_sent:
|
||||
greeting_transcript_sent = True
|
||||
await queue_transcript("assistant", cfg.greeting, time_now_iso8601())
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _client):
|
||||
if cfg.greeting:
|
||||
await worker.queue_frame(TTSSpeakFrame(cfg.greeting))
|
||||
context.add_message({"role": "assistant", "content": cfg.greeting})
|
||||
await worker.queue_frame(TTSSpeakFrame(cfg.greeting, append_to_context=False))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(_transport, _client):
|
||||
|
||||
@@ -1838,7 +1838,7 @@ function DebugVoicePanel({
|
||||
messages,
|
||||
audioInputs,
|
||||
selectedDeviceId,
|
||||
setSelectedDeviceId,
|
||||
selectDevice,
|
||||
sendText,
|
||||
connect,
|
||||
disconnect,
|
||||
@@ -1868,7 +1868,7 @@ function DebugVoicePanel({
|
||||
assistantId={assistantId}
|
||||
audioInputs={audioInputs}
|
||||
selectedDeviceId={selectedDeviceId}
|
||||
setSelectedDeviceId={setSelectedDeviceId}
|
||||
selectDevice={selectDevice}
|
||||
connect={connect}
|
||||
disconnect={disconnect}
|
||||
/>
|
||||
@@ -1941,9 +1941,8 @@ function DebugVoicePanel({
|
||||
<Select
|
||||
value={selectedDeviceId || "default"}
|
||||
onValueChange={(value) =>
|
||||
setSelectedDeviceId(value === "default" ? "" : value)
|
||||
selectDevice(value === "default" ? "" : value)
|
||||
}
|
||||
disabled={recording}
|
||||
>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
@@ -2044,7 +2043,7 @@ function VoiceSessionControls({
|
||||
assistantId,
|
||||
audioInputs,
|
||||
selectedDeviceId,
|
||||
setSelectedDeviceId,
|
||||
selectDevice,
|
||||
connect,
|
||||
disconnect,
|
||||
}: {
|
||||
@@ -2054,7 +2053,7 @@ function VoiceSessionControls({
|
||||
assistantId: string | null;
|
||||
audioInputs: MediaDeviceInfo[];
|
||||
selectedDeviceId: string;
|
||||
setSelectedDeviceId: (deviceId: string) => void;
|
||||
selectDevice: (deviceId: string) => void;
|
||||
connect: () => Promise<void>;
|
||||
disconnect: () => void;
|
||||
}) {
|
||||
@@ -2096,9 +2095,8 @@ function VoiceSessionControls({
|
||||
<Select
|
||||
value={selectedDeviceId || "default"}
|
||||
onValueChange={(value) =>
|
||||
setSelectedDeviceId(value === "default" ? "" : value)
|
||||
selectDevice(value === "default" ? "" : value)
|
||||
}
|
||||
disabled={recording}
|
||||
>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
|
||||
@@ -29,6 +29,9 @@ export type ChatMessage = {
|
||||
content: string;
|
||||
/** 后端给的 ISO 时间戳 */
|
||||
timestamp: string;
|
||||
sequence: number;
|
||||
turnId?: string;
|
||||
streaming?: boolean;
|
||||
};
|
||||
|
||||
// http→ws、https→wss,自动跟随 API 基址(同源反代时也对)
|
||||
@@ -54,6 +57,17 @@ function errorMessage(error: unknown, fallback: string): string {
|
||||
return fallback;
|
||||
}
|
||||
|
||||
function messageOrder(message: ChatMessage): number {
|
||||
const timestamp = Date.parse(message.timestamp);
|
||||
return Number.isNaN(timestamp) ? Number.MAX_SAFE_INTEGER : timestamp;
|
||||
}
|
||||
|
||||
function sortMessages(messages: ChatMessage[]): ChatMessage[] {
|
||||
return messages.sort(
|
||||
(a, b) => messageOrder(a) - messageOrder(b) || a.sequence - b.sequence,
|
||||
);
|
||||
}
|
||||
|
||||
function microphoneErrorMessage(error: unknown): string {
|
||||
if (error instanceof DOMException) {
|
||||
if (error.name === "NotAllowedError") {
|
||||
@@ -137,6 +151,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
const channel = dataChannelRef.current;
|
||||
dataChannelRef.current = null;
|
||||
if (channel) {
|
||||
channel.onopen = null;
|
||||
channel.onmessage = null;
|
||||
channel.close();
|
||||
}
|
||||
@@ -290,10 +305,61 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
// 由浏览器侧主动创建,后端 SmallWebRTCConnection 的 on("datachannel") 会接住。
|
||||
const channel = pc.createDataChannel("chat");
|
||||
dataChannelRef.current = channel;
|
||||
channel.onopen = () => {
|
||||
channel.send(JSON.stringify({ type: "client-ready" }));
|
||||
};
|
||||
channel.onmessage = (event) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data);
|
||||
if (
|
||||
msg?.type === "assistant-text-start" &&
|
||||
typeof msg.turn_id === "string"
|
||||
) {
|
||||
messageSeqRef.current += 1;
|
||||
const next: ChatMessage = {
|
||||
id: `assistant-${msg.turn_id}`,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
timestamp:
|
||||
typeof msg.timestamp === "string"
|
||||
? msg.timestamp
|
||||
: new Date().toISOString(),
|
||||
sequence: messageSeqRef.current,
|
||||
turnId: msg.turn_id,
|
||||
streaming: true,
|
||||
};
|
||||
setMessages((prev) => sortMessages([...prev, next]));
|
||||
} else if (
|
||||
msg?.type === "assistant-text-delta" &&
|
||||
typeof msg.turn_id === "string" &&
|
||||
typeof msg.delta === "string"
|
||||
) {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.turnId === msg.turn_id
|
||||
? { ...message, content: message.content + msg.delta }
|
||||
: message,
|
||||
),
|
||||
);
|
||||
} else if (
|
||||
msg?.type === "assistant-text-end" &&
|
||||
typeof msg.turn_id === "string"
|
||||
) {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.turnId === msg.turn_id
|
||||
? {
|
||||
...message,
|
||||
content:
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: message.content,
|
||||
streaming: false,
|
||||
}
|
||||
: message,
|
||||
),
|
||||
);
|
||||
} else if (
|
||||
msg?.type === "transcript" &&
|
||||
(msg.role === "user" || msg.role === "assistant") &&
|
||||
typeof msg.content === "string" &&
|
||||
@@ -308,8 +374,9 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
typeof msg.timestamp === "string"
|
||||
? msg.timestamp
|
||||
: new Date().toISOString(),
|
||||
sequence: messageSeqRef.current,
|
||||
};
|
||||
setMessages((prev) => [...prev, next]);
|
||||
setMessages((prev) => sortMessages([...prev, next]));
|
||||
}
|
||||
} catch {
|
||||
/* 非 JSON / 未知消息,忽略 */
|
||||
@@ -373,6 +440,56 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
}
|
||||
}, [assistantId, fail, refreshDevices]);
|
||||
|
||||
// 选择麦克风:更新选择;若会话正在发送麦克风音频,则用 WebRTC replaceTrack
|
||||
// 热切换轨道(无需重新协商),并把波形可视化重新接到新流。
|
||||
// 未连接时仅记下选择,留待下次 connect 生效。
|
||||
const selectDevice = useCallback(
|
||||
async (deviceId: string) => {
|
||||
setSelectedDeviceId(deviceId);
|
||||
selectedDeviceIdRef.current = deviceId;
|
||||
|
||||
const pc = pcRef.current;
|
||||
if (!pc) return;
|
||||
// 只有本就在发送麦克风音频(存在 audio sender 轨道)时才热切换;
|
||||
// 仅收听模式下加麦克风需重新协商,这里不处理,留到下次连接。
|
||||
const sender = pc
|
||||
.getSenders()
|
||||
.find((s) => s.track?.kind === "audio");
|
||||
if (!sender) return;
|
||||
|
||||
try {
|
||||
const audioConstraints: MediaTrackConstraints = {
|
||||
echoCancellation: true,
|
||||
noiseSuppression: true,
|
||||
autoGainControl: true,
|
||||
};
|
||||
if (deviceId) audioConstraints.deviceId = { exact: deviceId };
|
||||
const newStream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: audioConstraints,
|
||||
});
|
||||
// 切换期间可能已断开,丢弃刚拿到的流
|
||||
if (pcRef.current !== pc) {
|
||||
newStream.getTracks().forEach((t) => t.stop());
|
||||
return;
|
||||
}
|
||||
const newTrack = newStream.getAudioTracks()[0];
|
||||
if (!newTrack) {
|
||||
newStream.getTracks().forEach((t) => t.stop());
|
||||
return;
|
||||
}
|
||||
await sender.replaceTrack(newTrack);
|
||||
// 旧轨道停掉,新流替换(波形/分析器随 localStream 变化自动重连)
|
||||
localStreamRef.current?.getTracks().forEach((t) => t.stop());
|
||||
localStreamRef.current = newStream;
|
||||
setLocalStream(newStream);
|
||||
setMicWarning(null);
|
||||
} catch (mediaError) {
|
||||
setMicWarning(microphoneErrorMessage(mediaError));
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// 发送文字消息:后端先打断当前播报,再按用户输入触发新回复。
|
||||
// 成功返回 true;通道未就绪(未开始对话/连接中)返回 false。
|
||||
const sendText = useCallback((text: string): boolean => {
|
||||
@@ -396,6 +513,7 @@ export function useVoicePreview(assistantId: string | null) {
|
||||
audioInputs,
|
||||
selectedDeviceId,
|
||||
setSelectedDeviceId,
|
||||
selectDevice,
|
||||
sendText,
|
||||
connect,
|
||||
disconnect,
|
||||
|
||||
Reference in New Issue
Block a user