Unify db api

This commit is contained in:
Xin Wang
2026-02-26 01:58:39 +08:00
parent 56f8aa2191
commit 72ed7d0512
40 changed files with 3926 additions and 593 deletions

View File

@@ -14,7 +14,8 @@ event-driven design.
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import uuid
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
"name": "current_time",
@@ -79,7 +86,16 @@ class DuplexPipeline:
tts_service: Optional[BaseTTSService] = None,
asr_service: Optional[BaseASRService] = None,
system_prompt: Optional[str] = None,
greeting: Optional[str] = None
greeting: Optional[str] = None,
knowledge_searcher: Optional[
Callable[..., Awaitable[List[Dict[str, Any]]]]
] = None,
tool_resource_resolver: Optional[
Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
] = None,
server_tool_executor: Optional[
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
] = None,
):
"""
Initialize duplex pipeline.
@@ -96,6 +112,9 @@ class DuplexPipeline:
self.transport = transport
self.session_id = session_id
self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD
self.vad_model = SileroVAD(
@@ -117,9 +136,14 @@ class DuplexPipeline:
self.llm_service = llm_service
self.tts_service = tts_service
self.asr_service = asr_service # Will be initialized in start()
self._knowledge_searcher = knowledge_searcher
self._tool_resource_resolver = tool_resource_resolver
self._server_tool_executor = server_tool_executor
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager
self.conversation = ConversationManager(
@@ -153,6 +177,7 @@ class DuplexPipeline:
self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling
self._interrupt_event = asyncio.Event()
@@ -181,14 +206,48 @@ class DuplexPipeline:
self._runtime_barge_in_min_duration_ms: Optional[int] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
self._runtime_tools: List[Any] = list(raw_default_tools)
self._runtime_tool_executor: Dict[str, str] = {}
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
self._runtime_tool_executor = self._resolved_tool_executor_map()
if self._server_tool_executor is None:
if self._tool_resource_resolver:
async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
return await execute_server_tool(
call,
tool_resource_fetcher=self._tool_resource_resolver,
)
self._server_tool_executor = _executor
else:
self._server_tool_executor = execute_server_tool
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
@@ -276,6 +335,136 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
    """Return current effective runtime configuration without secrets.

    Merges per-session runtime overrides (``self._runtime_*`` dicts populated
    from session.start metadata) over the process-wide ``settings`` defaults.
    API keys are deliberately omitted from the returned payload.

    Returns:
        Nested dict with ``output``, ``services`` (llm/asr/tts), ``tools``
        (allowlist of enabled tool names) and ``tracks`` (track id mapping).
    """
    llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
    # Base-URL precedence: session override -> global setting -> provider default.
    llm_base_url = (
        self._runtime_llm.get("baseUrl")
        or settings.llm_api_url
        or self._default_llm_base_url(llm_provider)
    )
    tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
    asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
    output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
    if not output_mode:
        # No explicit mode requested: derive it from whether TTS output
        # is effectively enabled for this session.
        output_mode = "audio" if self._tts_output_enabled() else "text"
    return {
        "output": {"mode": output_mode},
        "services": {
            "llm": {
                "provider": llm_provider,
                "model": str(self._runtime_llm.get("model") or settings.llm_model),
                "baseUrl": llm_base_url,
            },
            "asr": {
                "provider": asr_provider,
                "model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
                "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
                "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
            },
            "tts": {
                "enabled": self._tts_output_enabled(),
                "provider": tts_provider,
                "model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
                "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
                # NOTE(review): `or` treats a 0 override as "unset" — presumably
                # intentional since speed 0 is not a valid playback rate.
                "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
            },
        },
        "tools": {
            "allowlist": self._resolved_tool_allowlist(),
        },
        "tracks": {
            "audio_in": self.track_audio_in,
            "audio_out": self.track_audio_out,
            "control": self.track_control,
        },
    }
def _next_event_seq(self) -> int:
    """Return the next monotonic sequence number for outbound events.

    Prefers the session-scoped provider installed via
    ``set_event_sequence_provider``; otherwise falls back to a
    pipeline-local counter.
    """
    provider = self._next_seq
    if provider:
        return provider()
    self._local_seq += 1
    return self._local_seq
def _event_source(self, event_type: str) -> str:
    """Classify an outbound event type into the subsystem that produced it."""
    if event_type.startswith(("transcript.", "input.speech_")):
        return "asr"
    if event_type.startswith("assistant.response."):
        return "llm"
    if event_type.startswith("assistant.tool_"):
        return "tool"
    if event_type == "metrics.ttfb" or event_type.startswith("output.audio."):
        return "tts"
    return "system"
def _new_id(self, prefix: str, counter: int) -> str:
    """Build a short correlation id such as ``turn_3_1a2b3c4d``."""
    suffix = uuid.uuid4().hex[:8]
    return "_".join((prefix, str(counter), suffix))
def _start_turn(self) -> str:
    """Open a new conversational turn and reset per-turn correlation ids."""
    self._turn_count += 1
    turn_id = self._new_id("turn", self._turn_count)
    self._current_turn_id = turn_id
    # Ids minted during the previous turn no longer apply.
    self._current_utterance_id = None
    self._current_response_id = None
    self._current_tts_id = None
    return turn_id
def _start_response(self) -> str:
    """Allocate a correlation id for the upcoming assistant response."""
    self._response_count += 1
    response_id = self._new_id("resp", self._response_count)
    self._current_response_id = response_id
    # Any in-flight TTS id belonged to the prior response.
    self._current_tts_id = None
    return response_id
def _start_tts(self) -> str:
    """Allocate a correlation id for the next synthesized audio stream."""
    self._tts_count += 1
    tts_id = self._new_id("tts", self._tts_count)
    self._current_tts_id = tts_id
    return tts_id
def _finalize_utterance(self) -> str:
    """Return the current utterance id, minting one (and its turn) if needed.

    Bug fix: the turn must be opened *before* the utterance id is stored.
    ``_start_turn()`` resets ``_current_utterance_id`` to ``None``, so the
    previous order (mint utterance id, then open the turn) both lost the
    freshly minted id and returned ``None`` whenever no turn was active.

    Returns:
        The (possibly newly minted) utterance correlation id.
    """
    if self._current_utterance_id:
        return self._current_utterance_id
    # An utterance can never exist outside a turn; ensure the turn first.
    if not self._current_turn_id:
        self._start_turn()
    self._utterance_count += 1
    self._current_utterance_id = self._new_id("utt", self._utterance_count)
    return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize an outbound event into the session envelope, in place.

    Fills in session id, sequence number, source, track id and a ``data``
    dict carrying correlation ids plus any non-reserved top-level keys.
    The input dict is mutated and returned.
    """
    event_type = str(event.get("type") or "")
    source = str(event.get("source") or self._event_source(event_type))
    track_id = event.get("trackId")
    if not track_id:
        # Route by producing subsystem when the caller gave no track.
        if source == "asr":
            track_id = self.track_audio_in
        elif source in {"llm", "tts", "tool"}:
            track_id = self.track_audio_out
        else:
            track_id = self.track_control
    payload = event.get("data")
    if not isinstance(payload, dict):
        payload = {}
    # Attach whichever correlation ids are currently active, without
    # clobbering ids the caller already placed in the payload.
    for key, value in (
        ("turn_id", self._current_turn_id),
        ("utterance_id", self._current_utterance_id),
        ("response_id", self._current_response_id),
        ("tts_id", self._current_tts_id),
    ):
        if value:
            payload.setdefault(key, value)
    # Fold loose top-level keys into the payload; envelope fields stay out.
    reserved = {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}
    for key, value in event.items():
        if key not in reserved:
            payload.setdefault(key, value)
    event["sessionId"] = self.session_id
    event["seq"] = self._next_event_seq()
    event["source"] = source
    event["trackId"] = track_id
    event["data"] = payload
    return event
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool):
@@ -295,6 +484,18 @@ class DuplexPipeline:
normalized = str(provider or "").strip().lower()
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
@staticmethod
def _is_llm_provider_supported(provider: Any) -> bool:
    """Return True when *provider* names an LLM backend the pipeline can drive."""
    supported = ("openai", "openai_compatible", "openai-compatible", "siliconflow")
    return str(provider or "").strip().lower() in supported
@staticmethod
def _default_llm_base_url(provider: Any) -> Optional[str]:
    """Fallback base URL for providers with a well-known public endpoint."""
    known = {"siliconflow": "https://api.siliconflow.cn/v1"}
    return known.get(str(provider or "").strip().lower())
def _tts_output_enabled(self) -> bool:
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
if enabled is not None:
@@ -370,20 +571,25 @@ class DuplexPipeline:
try:
# Connect LLM service
if not self.llm_service:
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
llm_base_url = (
self._runtime_llm.get("baseUrl")
or settings.llm_api_url
or self._default_llm_base_url(llm_provider)
)
llm_model = self._runtime_llm.get("model") or settings.llm_model
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
if llm_provider == "openai" and llm_api_key:
if self._is_llm_provider_supported(llm_provider) and llm_api_key:
self.llm_service = OpenAILLMService(
api_key=llm_api_key,
base_url=llm_base_url,
model=llm_model,
knowledge_config=self._resolved_knowledge_config(),
knowledge_searcher=self._knowledge_searcher,
)
else:
logger.warning("No OpenAI API key - using mock LLM")
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
self.llm_service = MockLLMService()
if hasattr(self.llm_service, "set_knowledge_config"):
@@ -399,20 +605,22 @@ class DuplexPipeline:
if tts_output_enabled:
if not self.tts_service:
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
tts_model = self._runtime_tts.get("model") or settings.tts_model
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
self.tts_service = OpenAICompatibleTTSService(
api_key=tts_api_key,
api_url=tts_api_url,
voice=tts_voice,
model=tts_model,
model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
sample_rate=settings.sample_rate,
speed=tts_speed
)
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
else:
self.tts_service = EdgeTTSService(
voice=tts_voice,
@@ -435,21 +643,23 @@ class DuplexPipeline:
# Connect ASR service
if not self.asr_service:
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
self.asr_service = OpenAICompatibleASRService(
api_key=asr_api_key,
model=asr_model,
api_url=asr_api_url,
model=asr_model or "FunAudioLLM/SenseVoiceSmall",
sample_rate=settings.sample_rate,
interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback
)
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
else:
self.asr_service = BufferedASRService(
sample_rate=settings.sample_rate
@@ -472,11 +682,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting
if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
text=greeting_to_speak,
trackId=self.session_id,
trackId=self.track_audio_out,
),
priority=20,
)
@@ -494,10 +706,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority)
await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority)
if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
    """Flush remaining outbound audio as one padded 20ms PCM frame."""
    tail, self._audio_out_frame_buffer = self._audio_out_frame_buffer, b""
    if not tail:
        return
    # Zero-pad so the client always receives whole fixed-size frames.
    padded = tail.ljust(self._PCM_FRAME_BYTES, b"\x00")
    await self._enqueue_outbound("audio", padded, priority)
async def _emit_transcript_delta(self, text: str) -> None:
    """Send an interim-transcript delta event on the audio-in track."""
    event = dict(
        ev(
            "transcript.delta",
            trackId=self.track_audio_in,
            text=text,
        )
    )
    await self._send_event(event, priority=30)
async def _emit_llm_delta(self, text: str) -> None:
    """Send an assistant.response.delta chunk on the audio-out track."""
    event = dict(
        ev(
            "assistant.response.delta",
            trackId=self.track_audio_out,
            text=text,
        )
    )
    await self._send_event(event, priority=20)
async def _flush_pending_llm_delta(self) -> None:
    """Emit any buffered LLM text and restart the throttle clock."""
    chunk, self._pending_llm_delta = self._pending_llm_delta, ""
    if not chunk:
        return
    self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
    await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +806,13 @@ class DuplexPipeline:
# Emit VAD event
await self.event_bus.publish(event_type, {
"trackId": self.session_id,
"trackId": self.track_audio_in,
"probability": probability
})
await self._send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
trackId=self.track_audio_in,
probability=probability,
),
priority=30,
@@ -661,6 +921,9 @@ class DuplexPipeline:
# Cancel any current speaking
await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn
await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +946,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final:
return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text
# Send transcript event to client
await self._send_event({
**ev(
"transcript.final" if is_final else "transcript.delta",
trackId=self.session_id,
text=text,
if is_final:
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
{
**ev(
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
}, priority=30)
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
"""Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
@@ -779,6 +1063,7 @@ class DuplexPipeline:
return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU.
@@ -786,7 +1071,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"transcript.final",
trackId=self.session_id,
trackId=self.track_audio_in,
text=user_text,
)
}, priority=25)
@@ -794,6 +1079,8 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False
self._pending_speech_audio = b""
@@ -881,6 +1168,23 @@ class DuplexPipeline:
result[name] = executor
return result
def _resolved_tool_allowlist(self) -> List[str]:
    """Collect the sorted, de-duplicated tool names enabled for this session.

    Accepts bare string entries as well as OpenAI-style tool dicts
    (``{"function": {"name": ...}}`` or ``{"name": ...}``); everything
    else is ignored.
    """
    collected: set[str] = set()
    for entry in self._runtime_tools:
        if isinstance(entry, str):
            collected.add(entry.strip())
        elif isinstance(entry, dict):
            fn = entry.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                collected.add(str(fn.get("name")).strip())
            elif entry.get("name"):
                collected.add(str(entry.get("name")).strip())
    # Drop empties produced by whitespace-only names.
    return sorted(name for name in collected if name)
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
fn = tool_call.get("function")
if isinstance(fn, dict):
@@ -894,6 +1198,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client.
return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the arguments of a tool call as a dict.

    Accepts either an already-parsed dict or a JSON string; a string that
    cannot be parsed into a dict is preserved under a ``"raw"`` key.
    """
    fn = tool_call.get("function")
    if not isinstance(fn, dict):
        return {}
    raw = fn.get("arguments")
    if isinstance(raw, dict):
        return raw
    if not (isinstance(raw, str) and raw.strip()):
        return {}
    try:
        parsed = json.loads(raw)
    except Exception:
        # Deliberate best-effort: malformed JSON is kept verbatim.
        return {"raw": raw}
    return parsed if isinstance(parsed, dict) else {"raw": raw}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a raw tool-execution result into the envelope shape.

    Args:
        result: Raw result dict carrying ``tool_call_id``/``id``, ``name``
            and an optional ``status`` block with ``code``/``message``.

    Returns:
        Dict with ``tool_call_id``, ``tool_name``, ``ok``, ``error``
        (``None`` on success) and the pass-through ``status``.
    """
    raw_status = result.get("status")
    status = raw_status if isinstance(raw_status, dict) else {}
    status_code = int(status.get("code") or 0) if status else 0
    status_message = str(status.get("message") or "") if status else ""
    tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
    tool_name = str(result.get("name") or "unknown_tool")
    ok = bool(200 <= status_code < 300)
    error: Optional[Dict[str, Any]] = None
    if not ok:
        # Fix: classify retryability from the *effective* code reported to
        # clients. Previously a missing status (code 0) was surfaced as a
        # 500 error yet marked non-retryable, contradicting the rule that
        # 5xx (and 408/429) failures are retryable.
        effective_code = status_code or 500
        retryable = effective_code >= 500 or effective_code in {429, 408}
        error = {
            "code": effective_code,
            "message": status_message or "tool_execution_failed",
            "retryable": retryable,
        }
    return {
        "tool_call_id": tool_call_id,
        "tool_name": tool_name,
        "ok": ok,
        "error": error,
        "status": {"code": status_code, "message": status_message},
    }
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1246,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip()
)
normalized = self._normalize_tool_result(result)
await self._send_event(
{
**ev(
"assistant.tool_result",
trackId=self.session_id,
trackId=self.track_audio_out,
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
)
},
@@ -927,6 +1274,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue
@@ -972,6 +1322,7 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
@@ -998,6 +1349,11 @@ class DuplexPipeline:
user_text: User's transcribed text
"""
try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1012,6 +1368,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False
first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
@@ -1028,6 +1386,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1045,11 +1404,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}"
)
tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call,
)
},
@@ -1071,19 +1438,13 @@ class DuplexPipeline:
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
self._pending_llm_delta += text_chunk
now_ms = time.monotonic() * 1000.0
if (
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
while True:
split_result = extract_tts_sentence(
@@ -1112,11 +1473,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1130,6 +1492,7 @@ class DuplexPipeline:
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if (
self._tts_output_enabled()
and remaining_text
@@ -1137,11 +1500,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1172,7 +1536,7 @@ class DuplexPipeline:
try:
result = await asyncio.wait_for(
execute_server_tool(call),
self._server_tool_executor(call),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
@@ -1204,11 +1568,12 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
trackId=self.track_audio_out,
text=full_response,
)
},
@@ -1217,10 +1582,11 @@ class DuplexPipeline:
# Send track end
if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1241,6 +1607,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
"""
@@ -1277,7 +1645,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1354,10 +1722,11 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
await self._send_event({
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1379,7 +1748,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1391,10 +1760,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01)
# Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1422,13 +1792,14 @@ class DuplexPipeline:
self._interrupt_event.set()
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({
**ev(
"response.interrupted",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=0)
@@ -1455,6 +1826,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()

View File

@@ -0,0 +1,244 @@
"""Async history bridge for non-blocking transcript persistence."""
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass
from typing import Any, Optional
from loguru import logger
@dataclass
class _HistoryTranscriptJob:
    """One buffered transcript row awaiting background persistence."""

    call_id: str  # remote call record id the row belongs to
    turn_index: int  # zero-based position of the row within the call
    speaker: str  # "human" or "ai" (history schema speaker labels)
    content: str  # stripped transcript text, never empty
    start_ms: int  # estimated utterance start, ms since call start
    end_ms: int  # utterance end, ms since call start
    duration_ms: int  # end_ms - start_ms, clamped to >= 1
class SessionHistoryBridge:
"""Session-scoped buffered history writer with background retries."""
_STOP_SENTINEL = object()
def __init__(
self,
*,
history_writer: Any,
enabled: bool,
queue_max_size: int,
retry_max_attempts: int,
retry_backoff_sec: float,
finalize_drain_timeout_sec: float,
):
self._history_writer = history_writer
self._enabled = bool(enabled and history_writer is not None)
self._queue_max_size = max(1, int(queue_max_size))
self._retry_max_attempts = max(0, int(retry_max_attempts))
self._retry_backoff_sec = max(0.0, float(retry_backoff_sec))
self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec))
self._call_id: Optional[str] = None
self._turn_index: int = 0
self._started_mono: Optional[float] = None
self._finalized: bool = False
self._worker_task: Optional[asyncio.Task] = None
self._finalize_lock = asyncio.Lock()
self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size)
@property
def enabled(self) -> bool:
return self._enabled
@property
def call_id(self) -> Optional[str]:
return self._call_id
async def start_call(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str,
) -> Optional[str]:
"""Create remote call record and start background worker."""
if not self._enabled or self._call_id:
return self._call_id
call_id = await self._history_writer.create_call_record(
user_id=user_id,
assistant_id=assistant_id,
source=source,
)
if not call_id:
return None
self._call_id = str(call_id)
self._turn_index = 0
self._finalized = False
self._started_mono = time.monotonic()
self._ensure_worker()
return self._call_id
def elapsed_ms(self) -> int:
if self._started_mono is None:
return 0
return max(0, int((time.monotonic() - self._started_mono) * 1000))
def enqueue_turn(self, *, role: str, text: str) -> bool:
"""Queue one transcript write without blocking the caller."""
if not self._enabled or not self._call_id or self._finalized:
return False
content = str(text or "").strip()
if not content:
return False
speaker = "human" if str(role or "").strip().lower() == "user" else "ai"
end_ms = self.elapsed_ms()
estimated_duration_ms = max(300, min(12000, len(content) * 80))
start_ms = max(0, end_ms - estimated_duration_ms)
job = _HistoryTranscriptJob(
call_id=self._call_id,
turn_index=self._turn_index,
speaker=speaker,
content=content,
start_ms=start_ms,
end_ms=end_ms,
duration_ms=max(1, end_ms - start_ms),
)
self._turn_index += 1
self._ensure_worker()
try:
self._queue.put_nowait(job)
return True
except asyncio.QueueFull:
logger.warning(
"History queue full; dropping transcript call_id={} turn={}",
self._call_id,
job.turn_index,
)
return False
async def finalize(self, *, status: str) -> bool:
"""Finalize history record once; waits briefly for queue drain."""
if not self._enabled or not self._call_id:
return False
async with self._finalize_lock:
if self._finalized:
return True
await self._drain_queue()
ok = await self._history_writer.finalize_call_record(
call_id=self._call_id,
status=status,
duration_seconds=self.duration_seconds(),
)
if ok:
self._finalized = True
await self._stop_worker()
return ok
async def shutdown(self) -> None:
"""Stop worker task and release queue resources."""
await self._stop_worker()
def duration_seconds(self) -> int:
if self._started_mono is None:
return 0
return max(0, int(time.monotonic() - self._started_mono))
def _ensure_worker(self) -> None:
if self._worker_task and not self._worker_task.done():
return
self._worker_task = asyncio.create_task(self._worker_loop())
async def _drain_queue(self) -> None:
if self._finalize_drain_timeout_sec <= 0:
return
try:
await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec)
except asyncio.TimeoutError:
logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec)
async def _stop_worker(self) -> None:
task = self._worker_task
if not task:
return
if task.done():
self._worker_task = None
return
sent = False
try:
self._queue.put_nowait(self._STOP_SENTINEL)
sent = True
except asyncio.QueueFull:
pass
if not sent:
try:
await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5)
except asyncio.TimeoutError:
task.cancel()
try:
await asyncio.wait_for(task, timeout=1.5)
except asyncio.TimeoutError:
task.cancel()
try:
await task
except Exception:
pass
except asyncio.CancelledError:
pass
finally:
self._worker_task = None
async def _worker_loop(self) -> None:
    """Consume transcript jobs until the stop sentinel is received.

    Every dequeued item — including the sentinel and failed writes — is
    acknowledged via task_done() in the finally block, so that
    _drain_queue()'s join() can always complete.
    """
    while True:
        item = await self._queue.get()
        try:
            if item is self._STOP_SENTINEL:
                return
            # NOTE: assert is stripped under `python -O`; any resulting
            # failure is still contained by the broad except below.
            assert isinstance(item, _HistoryTranscriptJob)
            await self._write_with_retry(item)
        except Exception as exc:
            logger.warning("History worker write failed unexpectedly: {}", exc)
        finally:
            self._queue.task_done()
async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool:
for attempt in range(self._retry_max_attempts + 1):
ok = await self._history_writer.add_transcript(
call_id=job.call_id,
turn_index=job.turn_index,
speaker=job.speaker,
content=job.content,
start_ms=job.start_ms,
end_ms=job.end_ms,
duration_ms=job.duration_ms,
)
if ok:
return True
if attempt >= self._retry_max_attempts:
logger.warning(
"History write dropped after retries call_id={} turn={}",
job.call_id,
job.turn_index,
)
return False
if self._retry_backoff_sec > 0:
await asyncio.sleep(self._retry_backoff_sec * (2**attempt))
return False

View File

@@ -0,0 +1,17 @@
"""Port interfaces for engine-side integration boundaries."""
from core.ports.backend import (
AssistantConfigProvider,
BackendGateway,
HistoryWriter,
KnowledgeSearcher,
ToolResourceResolver,
)
__all__ = [
"AssistantConfigProvider",
"BackendGateway",
"HistoryWriter",
"KnowledgeSearcher",
"ToolResourceResolver",
]

View File

@@ -0,0 +1,84 @@
"""Backend integration ports.
These interfaces define the boundary between engine runtime logic and
backend-side capabilities (config lookup, history persistence, retrieval,
and tool resource discovery).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Protocol
class AssistantConfigProvider(Protocol):
    """Port for loading trusted assistant runtime configuration."""

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Fetch assistant configuration payload.

        Args:
            assistant_id: Backend identifier of the assistant.

        Returns:
            The raw configuration mapping, or None when the assistant is
            unknown or the backend is unavailable.
        """
class HistoryWriter(Protocol):
    """Port for persisting call and transcript history."""

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Create a call record and return backend call ID.

        Returns None when the record could not be created.
        """

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Append one transcript turn segment.

        Args:
            call_id: Backend call ID from create_call_record().
            turn_index: Zero-based position of the turn within the call.
            speaker: Speaker label (e.g. "human" or "ai").
            content: Transcript text of the segment.
            start_ms: Segment start offset from call start, milliseconds.
            end_ms: Segment end offset from call start, milliseconds.
            confidence: Optional recognition confidence score.
            duration_ms: Optional explicit segment duration override.

        Returns:
            True when the segment was persisted.
        """

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Finalize a call record.

        Returns True when the record was successfully marked final.
        """
class KnowledgeSearcher(Protocol):
    """Port for RAG / knowledge retrieval operations."""

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search a knowledge source and return ranked snippets.

        Args:
            kb_id: Identifier of the knowledge base to query.
            query: Free-text search query.
            n_results: Maximum number of snippets to return.

        Returns:
            Ranked snippet payloads; empty list when nothing matches.
        """
class ToolResourceResolver(Protocol):
    """Port for resolving tool metadata/configuration."""

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Fetch tool resource configuration.

        Returns None when the tool is unknown or the backend fails.
        """
class BackendGateway(
    AssistantConfigProvider,
    HistoryWriter,
    KnowledgeSearcher,
    ToolResourceResolver,
    Protocol,
):
    """Composite backend gateway interface used by engine services.

    Structural union of all backend-side capability ports; an adapter
    satisfying every constituent Protocol satisfies this one.
    """

View File

@@ -1,22 +1,19 @@
"""Session management for active calls."""
import asyncio
import uuid
import hashlib
import json
import time
import re
import time
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
)
from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.history_bridge import SessionHistoryBridge
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
@@ -49,7 +46,39 @@ class Session:
Uses full duplex voice conversation pipeline.
"""
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(
self,
session_id: str,
transport: BaseTransport,
use_duplex: bool = None,
backend_gateway: Optional[Any] = None,
):
"""
Initialize session.
@@ -61,12 +90,23 @@ class Session:
self.id = session_id
self.transport = transport
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
self._history_bridge = SessionHistoryBridge(
history_writer=self._backend_gateway,
enabled=settings.history_enabled,
queue_max_size=settings.history_queue_max_size,
retry_max_attempts=settings.history_retry_max_attempts,
retry_backoff_sec=settings.history_retry_backoff_sec,
finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
)
self.pipeline = DuplexPipeline(
transport=transport,
session_id=session_id,
system_prompt=settings.duplex_system_prompt,
greeting=settings.duplex_greeting
greeting=settings.duplex_greeting,
knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
)
# Session state
@@ -78,17 +118,15 @@ class Session:
self.authenticated: bool = False
# Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4())
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
self._history_finalized: bool = False
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._cleanup_lock = asyncio.Lock()
self._cleaned_up = False
self.workflow_runner: Optional[WorkflowRunner] = None
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +167,47 @@ class Session:
"client",
"Audio received before session.start",
"protocol.order",
stage="protocol",
retryable=False,
)
return
try:
await self.pipeline.process_audio(audio_bytes)
if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
if len(audio_bytes) % frame_bytes != 0:
await self._send_error(
"client",
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=False,
)
return
for i in range(0, len(audio_bytes), frame_bytes):
frame = audio_bytes[i : i + frame_bytes]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers."""
@@ -176,7 +248,7 @@ class Session:
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
@@ -198,9 +270,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
@@ -217,10 +289,9 @@ class Session:
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
await self._send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
@@ -231,8 +302,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -248,28 +323,37 @@ class Session:
self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
await self._send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
audio=message.audio or {},
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio.model_dump() if message.audio else {},
)
)
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
await self._send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
@@ -285,17 +369,24 @@ class Session:
stop_reason = reason or "client_requested"
self.state = "hungup"
self.ws_state = WsSessionState.STOPPED
await self.transport.send_event(
await self._send_event(
ev(
"session.stopped",
sessionId=self.id,
reason=stop_reason,
)
)
await self._finalize_history(status="connected")
await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None:
"""
Send error event to client.
@@ -304,13 +395,26 @@ class Session:
error_message: Error message
code: Machine-readable error code
"""
await self.transport.send_event(
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event(
ev(
"error",
sender=sender,
code=code,
message=error_message,
trackId=self.current_track_id,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=resolved_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
)
)
@@ -329,11 +433,12 @@ class Session:
logger.info(f"Session {self.id} cleaning up")
await self._finalize_history(status="connected")
await self.pipeline.cleanup()
await self._history_bridge.shutdown()
await self.transport.close()
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
"""Initialize backend history call record for this session."""
if self._history_call_id:
if self._history_bridge.call_id:
return
history_meta: Dict[str, Any] = {}
@@ -349,7 +454,7 @@ class Session:
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
source = str(history_meta.get("source", metadata.get("source", "debug")))
call_id = await create_history_call_record(
call_id = await self._history_bridge.start_call(
user_id=user_id,
assistant_id=str(assistant_id) if assistant_id else None,
source=source,
@@ -357,10 +462,6 @@ class Session:
if not call_id:
return
self._history_call_id = call_id
self._history_call_started_mono = time.monotonic()
self._history_turn_index = 0
self._history_finalized = False
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
@@ -372,48 +473,11 @@ class Session:
elif role == "assistant":
await self._maybe_advance_workflow(turn.text.strip())
if not self._history_call_id:
return
if not turn.text or not turn.text.strip():
return
role = (turn.role or "").lower()
speaker = "human" if role == "user" else "ai"
end_ms = 0
if self._history_call_started_mono is not None:
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
start_ms = max(0, end_ms - estimated_duration_ms)
turn_index = self._history_turn_index
await add_history_transcript(
call_id=self._history_call_id,
turn_index=turn_index,
speaker=speaker,
content=turn.text.strip(),
start_ms=start_ms,
end_ms=end_ms,
duration_ms=max(1, end_ms - start_ms),
)
self._history_turn_index += 1
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
async def _finalize_history(self, status: str) -> None:
"""Finalize history call record once."""
if not self._history_call_id or self._history_finalized:
return
duration_seconds = 0
if self._history_call_started_mono is not None:
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
ok = await finalize_history_call_record(
call_id=self._history_call_id,
status=status,
duration_seconds=duration_seconds,
)
if ok:
self._history_finalized = True
await self._history_bridge.finalize(status=status)
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Parse workflow payload and return initial runtime overrides."""
@@ -483,10 +547,9 @@ class Session:
node = transition.node
edge = transition.edge
await self.transport.send_event(
await self._send_event(
ev(
"workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id,
fromNodeId=edge.from_node_id,
@@ -494,10 +557,9 @@ class Session:
reason=reason,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
nodeName=node.name,
@@ -510,10 +572,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
tool=node.tool or {},
@@ -522,10 +583,9 @@ class Session:
return
if node.node_type == "human_transfer":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
@@ -534,16 +594,77 @@ class Session:
return
if node.node_type == "end":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
)
await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize an outgoing event into the WS v1 envelope, in place.

    Guarantees sessionId/seq/source/trackId/data are present.  Loose
    top-level fields (anything not part of the envelope) are mirrored
    into ``data`` without overwriting keys that ``data`` already has.
    Mutates and returns the same dict.
    """
    event_type = str(event.get("type") or "")
    source = str(event.get("source") or self._event_source(event_type))
    track_id = event.get("trackId") or self.TRACK_CONTROL
    data = event.get("data")
    if not isinstance(data, dict):
        data = {}
    # Mirror legacy top-level payload fields into the data envelope so
    # consumers can rely on event["data"] alone; setdefault preserves
    # values already placed in data explicitly.
    for k, v in event.items():
        if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
            continue
        data.setdefault(k, v)
    event["sessionId"] = self.id
    # Sequence number is assigned last-moment so envelopes are strictly
    # ordered in emission order.
    event["seq"] = self._next_event_seq()
    event["source"] = source
    event["trackId"] = track_id
    event["data"] = data
    return event
async def _send_event(self, event: Dict[str, Any]) -> None:
    """Envelope and emit one event over the session transport."""
    await self.transport.send_event(self._envelope_event(event))

async def send_heartbeat(self) -> None:
    """Emit a keep-alive heartbeat event on the control track."""
    await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))
async def _workflow_llm_route(
self,
node: WorkflowNodeDef,
@@ -629,6 +750,137 @@ class Session:
merged[key] = value
return merged
async def _load_server_runtime_metadata(
    self,
    client_metadata: Dict[str, Any],
    workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
    """Load trusted runtime metadata from backend assistant config.

    Resolves the assistant ID (workflow runtime takes precedence over
    client metadata, then appId/app_id aliases), fetches the backend
    config through the gateway, and projects it onto the runtime keys
    the session understands.

    Returns:
        Runtime overrides dict; empty when no assistant ID is known,
        the gateway lacks fetch_assistant_config, or the payload is not
        a mapping.
    """
    assistant_id = (
        workflow_runtime.get("assistantId")
        or client_metadata.get("assistantId")
        or client_metadata.get("appId")
        or client_metadata.get("app_id")
    )
    if assistant_id is None:
        return {}
    provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
    if not callable(provider):
        return {}
    payload = await provider(str(assistant_id).strip())
    if not isinstance(payload, dict):
        return {}
    # Merge precedence inside the payload: sessionStartMetadata first,
    # then the nested "assistant" object on top of it; fall back to the
    # raw payload only when neither produced anything.
    assistant_cfg: Dict[str, Any] = {}
    session_start_cfg = payload.get("sessionStartMetadata")
    if isinstance(session_start_cfg, dict):
        assistant_cfg.update(session_start_cfg)
    if isinstance(payload.get("assistant"), dict):
        assistant_cfg.update(payload.get("assistant"))
    elif not assistant_cfg:
        assistant_cfg = payload
    if not isinstance(assistant_cfg, dict):
        return {}
    runtime: Dict[str, Any] = {}
    # Keys copied through verbatim when present in the backend config.
    passthrough_keys = {
        "firstTurnMode",
        "generatedOpenerEnabled",
        "output",
        "bargeIn",
        "knowledgeBaseId",
        "knowledge",
        "history",
        "userId",
        "source",
        "tools",
        "services",
        "configVersionId",
        "config_version_id",
    }
    for key in passthrough_keys:
        if key in assistant_cfg:
            runtime[key] = assistant_cfg[key]
    # systemPrompt/greeting accept legacy aliases (prompt/opener).
    if assistant_cfg.get("systemPrompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
    elif assistant_cfg.get("prompt") is not None:
        runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
    if assistant_cfg.get("greeting") is not None:
        runtime["greeting"] = assistant_cfg.get("greeting")
    elif assistant_cfg.get("opener") is not None:
        runtime["greeting"] = assistant_cfg.get("opener")
    resolved_assistant_id = (
        assistant_cfg.get("assistantId")
        or payload.get("assistantId")
        or assistant_id
    )
    runtime["assistantId"] = str(resolved_assistant_id)
    # Normalize configVersionId: payload-level value, then the
    # snake_case alias, never overwriting an already-resolved value.
    if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
        runtime["configVersionId"] = payload.get("configVersionId")
    if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
        runtime["configVersionId"] = payload.get("config_version_id")
    if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
        runtime["configVersionId"] = runtime.get("config_version_id")
    return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the client metadata whitelist and reject service overrides.

    Client-supplied ``services`` (which could smuggle credentials or
    endpoint overrides) is never honored; its presence is logged as a
    warning and the key is dropped by the whitelist filter.
    """
    sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
    if isinstance(metadata.get("services"), dict):
        logger.warning(
            "Session {} provided metadata.services from client; client-side service config is ignored",
            self.id,
        )
    return sanitized
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Build the public ``config.resolved`` payload (secrets removed).

    The system prompt itself is never exposed — only its SHA-256 digest,
    so clients can detect prompt changes without seeing the content.
    Service/tool/output sections come from the pipeline's resolved
    runtime config.
    """
    system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
    # Hash is omitted (None) when no prompt is configured at all.
    prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
    runtime = self.pipeline.resolved_runtime_config()
    return {
        "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
        "channel": metadata.get("channel"),
        "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
        "prompt": {"sha256": prompt_hash},
        "output": runtime.get("output", {}),
        "services": runtime.get("services", {}),
        "tools": runtime.get("tools", {}),
        "tracks": {
            "audio_in": self.TRACK_AUDIO_IN,
            "audio_out": self.TRACK_AUDIO_OUT,
            "control": self.TRACK_CONTROL,
        },
    }
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:

View File

@@ -4,11 +4,13 @@ import asyncio
import ast
import operator
from datetime import datetime
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict, Optional
import aiohttp
from app.backend_client import fetch_tool_resource
from app.backend_adapters import build_backend_adapter_from_settings
ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
_BIN_OPS = {
ast.Add: operator.add,
@@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
return {}
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
    """Default tool resource resolver backed by the settings adapter.

    Builds a fresh backend adapter on every call; callers that need
    connection reuse should inject their own ToolResourceFetcher into
    execute_server_tool instead.
    """
    adapter = build_backend_adapter_from_settings()
    return await adapter.fetch_tool_resource(tool_id)
async def execute_server_tool(
tool_call: Dict[str, Any],
tool_resource_fetcher: Optional[ToolResourceFetcher] = None,
) -> Dict[str, Any]:
"""Execute a server-side tool and return normalized result payload."""
call_id = str(tool_call.get("id") or "").strip()
tool_name = _extract_tool_name(tool_call)
args = _extract_tool_args(tool_call)
resource_fetcher = tool_resource_fetcher or fetch_tool_resource
if tool_name == "calculator":
expression = str(args.get("expression") or "").strip()
@@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
}
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
resource = await fetch_tool_resource(tool_name)
resource = await resource_fetcher(tool_name)
if resource and str(resource.get("category") or "") == "query":
method = str(resource.get("http_method") or "GET").strip().upper()
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: