Add text_msg_prompt tool to DuplexPipeline and the DebugDrawer client tool list, with parameter validation ('msg' required) and modal state management for displaying messages. Also propagate explicit turn_id/utterance_id/response_id on outbound pipeline events (response deltas/finals, output.audio.start/end, metrics.ttfb, response.interrupted), and have DebugDrawer track interrupted response IDs so late deltas, finals, and TTFB metrics from an interrupted response are dropped.

This commit is contained in:
Xin Wang
2026-02-27 16:47:49 +08:00
parent cdd8275e35
commit 95c6e93a9c
2 changed files with 207 additions and 33 deletions

View File

@@ -157,6 +157,17 @@ class DuplexPipeline:
"required": ["msg"],
},
},
"text_msg_prompt": {
"name": "text_msg_prompt",
"description": "Show a text message prompt dialog on client side",
"parameters": {
"type": "object",
"properties": {
"msg": {"type": "string", "description": "Message text to display"},
},
"required": ["msg"],
},
},
}
_DEFAULT_CLIENT_EXECUTORS = frozenset({
"turn_on_camera",
@@ -164,6 +175,7 @@ class DuplexPipeline:
"increase_volume",
"decrease_volume",
"voice_message_prompt",
"text_msg_prompt",
})
def __init__(
@@ -559,6 +571,10 @@ class DuplexPipeline:
data = event.get("data")
if not isinstance(data, dict):
data = {}
explicit_turn_id = str(event.get("turn_id") or "").strip() or None
explicit_utterance_id = str(event.get("utterance_id") or "").strip() or None
explicit_response_id = str(event.get("response_id") or "").strip() or None
explicit_tts_id = str(event.get("tts_id") or "").strip() or None
if self._current_turn_id:
data.setdefault("turn_id", self._current_turn_id)
if self._current_utterance_id:
@@ -567,9 +583,29 @@ class DuplexPipeline:
data.setdefault("response_id", self._current_response_id)
if self._current_tts_id:
data.setdefault("tts_id", self._current_tts_id)
if explicit_turn_id:
data["turn_id"] = explicit_turn_id
if explicit_utterance_id:
data["utterance_id"] = explicit_utterance_id
if explicit_response_id:
data["response_id"] = explicit_response_id
if explicit_tts_id:
data["tts_id"] = explicit_tts_id
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
if k in {
"type",
"timestamp",
"sessionId",
"seq",
"source",
"trackId",
"data",
"turn_id",
"utterance_id",
"response_id",
"tts_id",
}:
continue
data.setdefault(k, v)
@@ -1027,25 +1063,50 @@ class DuplexPipeline:
priority=30,
)
async def _emit_llm_delta(self, text: str) -> None:
await self._send_event(
{
async def _emit_llm_delta(
self,
text: str,
*,
turn_id: Optional[str] = None,
utterance_id: Optional[str] = None,
response_id: Optional[str] = None,
) -> None:
event = {
**ev(
"assistant.response.delta",
trackId=self.track_audio_out,
text=text,
)
},
}
if turn_id:
event["turn_id"] = turn_id
if utterance_id:
event["utterance_id"] = utterance_id
if response_id:
event["response_id"] = response_id
await self._send_event(
event,
priority=20,
)
async def _flush_pending_llm_delta(self) -> None:
async def _flush_pending_llm_delta(
self,
*,
turn_id: Optional[str] = None,
utterance_id: Optional[str] = None,
response_id: Optional[str] = None,
) -> None:
if not self._pending_llm_delta:
return
chunk = self._pending_llm_delta
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
await self._emit_llm_delta(chunk)
await self._emit_llm_delta(
chunk,
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -1761,7 +1822,9 @@ class DuplexPipeline:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
turn_id = self._current_turn_id
utterance_id = self._current_utterance_id
response_id = self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1795,7 +1858,11 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
await self._flush_pending_llm_delta(
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1869,7 +1936,11 @@ class DuplexPipeline:
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
await self._flush_pending_llm_delta(
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
if use_engine_sentence_split:
while True:
@@ -1905,7 +1976,10 @@ class DuplexPipeline:
**ev(
"output.audio.start",
trackId=self.track_audio_out,
)
),
"turn_id": turn_id,
"utterance_id": utterance_id,
"response_id": response_id,
},
priority=30,
)
@@ -1915,13 +1989,20 @@ class DuplexPipeline:
sentence,
fade_in_ms=0,
fade_out_ms=8,
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
if use_engine_sentence_split:
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
else:
remaining_text = sentence_buffer.strip()
await self._flush_pending_llm_delta()
await self._flush_pending_llm_delta(
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
if (
self._tts_output_enabled()
and remaining_text
@@ -1935,7 +2016,10 @@ class DuplexPipeline:
**ev(
"output.audio.start",
trackId=self.track_audio_out,
)
),
"turn_id": turn_id,
"utterance_id": utterance_id,
"response_id": response_id,
},
priority=30,
)
@@ -1944,6 +2028,9 @@ class DuplexPipeline:
remaining_text,
fade_in_ms=0,
fade_out_ms=8,
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
if not tool_calls:
@@ -2007,14 +2094,21 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._flush_pending_llm_delta(
turn_id=turn_id,
utterance_id=utterance_id,
response_id=response_id,
)
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.track_audio_out,
text=full_response,
)
),
"turn_id": turn_id,
"utterance_id": utterance_id,
"response_id": response_id,
},
priority=20,
)
@@ -2026,7 +2120,10 @@ class DuplexPipeline:
**ev(
"output.audio.end",
trackId=self.track_audio_out,
)
),
"turn_id": turn_id,
"utterance_id": utterance_id,
"response_id": response_id,
}, priority=10)
# End assistant turn
@@ -2049,7 +2146,15 @@ class DuplexPipeline:
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
async def _speak_sentence(
self,
text: str,
fade_in_ms: int = 0,
fade_out_ms: int = 8,
turn_id: Optional[str] = None,
utterance_id: Optional[str] = None,
response_id: Optional[str] = None,
) -> None:
"""
Synthesize and send a single sentence.
@@ -2086,7 +2191,10 @@ class DuplexPipeline:
"metrics.ttfb",
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
),
"turn_id": turn_id,
"utterance_id": utterance_id,
"response_id": response_id,
}, priority=25)
# Double-check interrupt right before sending audio
@@ -2233,6 +2341,9 @@ class DuplexPipeline:
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
interrupted_turn_id = self._current_turn_id
interrupted_utterance_id = self._current_utterance_id
interrupted_response_id = self._current_response_id
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
@@ -2240,7 +2351,10 @@ class DuplexPipeline:
**ev(
"response.interrupted",
trackId=self.track_audio_out,
)
),
"turn_id": interrupted_turn_id,
"utterance_id": interrupted_utterance_id,
"response_id": interrupted_response_id,
}, priority=0)
# Cancel TTS

View File

@@ -1689,6 +1689,13 @@ const TOOL_PARAMETER_HINTS: Record<string, any> = {
},
required: ['msg'],
},
text_msg_prompt: {
type: 'object',
properties: {
msg: { type: 'string', description: 'Message text to display in debug drawer modal' },
},
required: ['msg'],
},
code_interpreter: {
type: 'object',
properties: {
@@ -1707,6 +1714,7 @@ const DEBUG_CLIENT_TOOLS = [
{ id: 'increase_volume', name: 'increase_volume', description: '调高音量' },
{ id: 'decrease_volume', name: 'decrease_volume', description: '调低音量' },
{ id: 'voice_message_prompt', name: 'voice_message_prompt', description: '语音消息提示' },
{ id: 'text_msg_prompt', name: 'text_msg_prompt', description: '文本消息提示' },
] as const;
const DEBUG_CLIENT_TOOL_ID_SET = new Set<string>(DEBUG_CLIENT_TOOLS.map((item) => item.id));
@@ -1973,6 +1981,7 @@ export const DebugDrawer: React.FC<{
const [inputText, setInputText] = useState('');
const [isLoading, setIsLoading] = useState(false);
const [callStatus, setCallStatus] = useState<'idle' | 'calling' | 'active'>('idle');
const [textPromptDialog, setTextPromptDialog] = useState<{ open: boolean; message: string }>({ open: false, message: '' });
const [textSessionStarted, setTextSessionStarted] = useState(false);
const [wsStatus, setWsStatus] = useState<'disconnected' | 'connecting' | 'ready' | 'error'>('disconnected');
const [wsError, setWsError] = useState('');
@@ -2033,6 +2042,7 @@ export const DebugDrawer: React.FC<{
const assistantDraftIndexRef = useRef<number | null>(null);
const assistantResponseIndexByIdRef = useRef<Map<string, number>>(new Map());
const pendingTtfbByResponseIdRef = useRef<Map<string, number>>(new Map());
const interruptedResponseIdsRef = useRef<Set<string>>(new Set());
const audioCtxRef = useRef<AudioContext | null>(null);
const playbackTimeRef = useRef<number>(0);
const activeAudioSourcesRef = useRef<Set<AudioBufferSourceNode>>(new Set());
@@ -2101,6 +2111,13 @@ export const DebugDrawer: React.FC<{
assistantDraftIndexRef.current = null;
assistantResponseIndexByIdRef.current.clear();
pendingTtfbByResponseIdRef.current.clear();
interruptedResponseIdsRef.current.clear();
};
const extractResponseId = (payload: any): string | undefined => {
const responseIdRaw = payload?.data?.response_id ?? payload?.response_id ?? payload?.responseId;
const responseId = String(responseIdRaw || '').trim();
return responseId || undefined;
};
// Initialize
@@ -2120,6 +2137,7 @@ export const DebugDrawer: React.FC<{
stopVoiceCapture();
stopMedia();
closeWs();
setTextPromptDialog({ open: false, message: '' });
if (audioCtxRef.current) {
void audioCtxRef.current.close();
audioCtxRef.current = null;
@@ -2480,6 +2498,7 @@ export const DebugDrawer: React.FC<{
setCallStatus('idle');
clearResponseTracking();
setMessages([]);
setTextPromptDialog({ open: false, message: '' });
lastUserFinalRef.current = '';
setIsLoading(false);
};
@@ -2903,6 +2922,14 @@ export const DebugDrawer: React.FC<{
}
if (type === 'response.interrupted') {
const interruptedResponseId = extractResponseId(payload);
if (interruptedResponseId) {
interruptedResponseIdsRef.current.add(interruptedResponseId);
if (interruptedResponseIdsRef.current.size > 64) {
const oldest = interruptedResponseIdsRef.current.values().next().value as string | undefined;
if (oldest) interruptedResponseIdsRef.current.delete(oldest);
}
}
assistantDraftIndexRef.current = null;
setIsLoading(false);
stopPlaybackImmediately();
@@ -2913,8 +2940,8 @@ export const DebugDrawer: React.FC<{
const maybeTtfb = Number(payload?.latencyMs ?? payload?.data?.latencyMs);
if (!Number.isFinite(maybeTtfb) || maybeTtfb < 0) return;
const ttfbMs = Math.round(maybeTtfb);
const responseIdRaw = payload?.data?.response_id ?? payload?.response_id ?? payload?.responseId;
const responseId = String(responseIdRaw || '').trim();
const responseId = extractResponseId(payload);
if (responseId && interruptedResponseIdsRef.current.has(responseId)) return;
if (responseId) {
const indexed = assistantResponseIndexByIdRef.current.get(responseId);
if (typeof indexed === 'number') {
@@ -3065,6 +3092,16 @@ export const DebugDrawer: React.FC<{
resultPayload.output = { message: 'speech_synthesis_unavailable', msg };
resultPayload.status = { code: 503, message: 'speech_output_unavailable' };
}
} else if (toolName === 'text_msg_prompt') {
const msg = String(parsedArgs?.msg || '').trim();
if (!msg) {
resultPayload.output = { message: "Missing required argument 'msg'" };
resultPayload.status = { code: 422, message: 'invalid_arguments' };
} else {
setTextPromptDialog({ open: true, message: msg });
resultPayload.output = { message: 'text_prompt_shown', msg };
resultPayload.status = { code: 200, message: 'ok' };
}
}
} catch (err) {
resultPayload.output = {
@@ -3183,8 +3220,8 @@ export const DebugDrawer: React.FC<{
if (type === 'assistant.response.delta') {
const delta = String(payload.text || '');
if (!delta) return;
const responseIdRaw = payload?.data?.response_id ?? payload?.response_id ?? payload?.responseId;
const responseId = String(responseIdRaw || '').trim() || undefined;
const responseId = extractResponseId(payload);
if (responseId && interruptedResponseIdsRef.current.has(responseId)) return;
setMessages((prev) => {
let idx = assistantDraftIndexRef.current;
if (idx === null || !prev[idx] || prev[idx].role !== 'model') {
@@ -3250,8 +3287,8 @@ export const DebugDrawer: React.FC<{
if (type === 'assistant.response.final') {
const finalText = String(payload.text || '');
const responseIdRaw = payload?.data?.response_id ?? payload?.response_id ?? payload?.responseId;
const responseId = String(responseIdRaw || '').trim() || undefined;
const responseId = extractResponseId(payload);
if (responseId && interruptedResponseIdsRef.current.has(responseId)) return;
setMessages((prev) => {
let idx = assistantDraftIndexRef.current;
assistantDraftIndexRef.current = null;
@@ -3771,6 +3808,29 @@ export const DebugDrawer: React.FC<{
</div>
</div>
</div>
{textPromptDialog.open && (
<div className="absolute inset-0 z-40 flex items-center justify-center bg-black/55 backdrop-blur-[1px]">
<div className="relative w-[92%] max-w-md rounded-xl border border-white/15 bg-card/95 p-4 shadow-2xl animate-in zoom-in-95 duration-200">
<button
type="button"
onClick={() => setTextPromptDialog({ open: false, message: '' })}
className="absolute right-3 top-3 rounded-sm opacity-70 hover:opacity-100 text-muted-foreground hover:text-foreground transition-opacity"
title="关闭"
>
<X className="h-4 w-4" />
</button>
<div className="mb-3 pr-6">
<div className="text-[10px] font-black tracking-[0.14em] uppercase text-amber-300">文本消息提示</div>
<p className="mt-2 text-sm leading-6 text-foreground whitespace-pre-wrap break-words">{textPromptDialog.message}</p>
</div>
<div className="flex justify-end">
<Button size="sm" onClick={() => setTextPromptDialog({ open: false, message: '' })}>
知道了
</Button>
</div>
</div>
</div>
)}
</div>
</Drawer>
{isOpen && (