Unify DB API
@@ -14,7 +14,8 @@ event-driven design.
 import asyncio
 import json
 import time
-from typing import Any, Dict, List, Optional, Tuple
+import uuid
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

 import numpy as np
 from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
     _MIN_SPLIT_SPOKEN_CHARS = 6
     _TOOL_WAIT_TIMEOUT_SECONDS = 15.0
     _SERVER_TOOL_TIMEOUT_SECONDS = 15.0
+    TRACK_AUDIO_IN = "audio_in"
+    TRACK_AUDIO_OUT = "audio_out"
+    TRACK_CONTROL = "control"
+    _PCM_FRAME_BYTES = 640  # 16k mono pcm_s16le, 20ms
+    _ASR_DELTA_THROTTLE_MS = 300
+    _LLM_DELTA_THROTTLE_MS = 80
     _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
         "current_time": {
             "name": "current_time",
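The 640-byte frame constant follows directly from the PCM parameters in its comment; a quick standalone check (variable names here are illustrative, not part of the diff):

    SAMPLE_RATE_HZ = 16_000   # 16 kHz mono
    BYTES_PER_SAMPLE = 2      # pcm_s16le = 16-bit little-endian
    FRAME_MS = 20
    frame_bytes = SAMPLE_RATE_HZ * FRAME_MS // 1000 * BYTES_PER_SAMPLE
    assert frame_bytes == 640  # matches _PCM_FRAME_BYTES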
@@ -79,7 +86,16 @@ class DuplexPipeline:
         tts_service: Optional[BaseTTSService] = None,
         asr_service: Optional[BaseASRService] = None,
         system_prompt: Optional[str] = None,
-        greeting: Optional[str] = None
+        greeting: Optional[str] = None,
+        knowledge_searcher: Optional[
+            Callable[..., Awaitable[List[Dict[str, Any]]]]
+        ] = None,
+        tool_resource_resolver: Optional[
+            Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
+        ] = None,
+        server_tool_executor: Optional[
+            Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
+        ] = None,
     ):
         """
         Initialize duplex pipeline.
@@ -96,6 +112,9 @@ class DuplexPipeline:
         self.transport = transport
         self.session_id = session_id
         self.event_bus = get_event_bus()
+        self.track_audio_in = self.TRACK_AUDIO_IN
+        self.track_audio_out = self.TRACK_AUDIO_OUT
+        self.track_control = self.TRACK_CONTROL

         # Initialize VAD
         self.vad_model = SileroVAD(
@@ -117,9 +136,14 @@ class DuplexPipeline:
         self.llm_service = llm_service
         self.tts_service = tts_service
         self.asr_service = asr_service  # Will be initialized in start()
+        self._knowledge_searcher = knowledge_searcher
+        self._tool_resource_resolver = tool_resource_resolver
+        self._server_tool_executor = server_tool_executor

         # Track last sent transcript to avoid duplicates
         self._last_sent_transcript = ""
+        self._pending_transcript_delta: str = ""
+        self._last_transcript_delta_emit_ms: float = 0.0

         # Conversation manager
         self.conversation = ConversationManager(
@@ -153,6 +177,7 @@ class DuplexPipeline:
         self._outbound_seq = 0
         self._outbound_task: Optional[asyncio.Task] = None
         self._drop_outbound_audio = False
+        self._audio_out_frame_buffer: bytes = b""

         # Interruption handling
         self._interrupt_event = asyncio.Event()
@@ -181,14 +206,48 @@ class DuplexPipeline:
         self._runtime_barge_in_min_duration_ms: Optional[int] = None
         self._runtime_knowledge: Dict[str, Any] = {}
         self._runtime_knowledge_base_id: Optional[str] = None
-        self._runtime_tools: List[Any] = []
+        raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
+        self._runtime_tools: List[Any] = list(raw_default_tools)
         self._runtime_tool_executor: Dict[str, str] = {}
         self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
         self._early_tool_results: Dict[str, Dict[str, Any]] = {}
         self._completed_tool_call_ids: set[str] = set()
+        self._pending_client_tool_call_ids: set[str] = set()
+        self._next_seq: Optional[Callable[[], int]] = None
+        self._local_seq: int = 0
+
+        # Cross-service correlation IDs
+        self._turn_count: int = 0
+        self._response_count: int = 0
+        self._tts_count: int = 0
+        self._utterance_count: int = 0
+        self._current_turn_id: Optional[str] = None
+        self._current_utterance_id: Optional[str] = None
+        self._current_response_id: Optional[str] = None
+        self._current_tts_id: Optional[str] = None
+        self._pending_llm_delta: str = ""
+        self._last_llm_delta_emit_ms: float = 0.0

         self._runtime_tool_executor = self._resolved_tool_executor_map()

+        if self._server_tool_executor is None:
+            if self._tool_resource_resolver:
+
+                async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
+                    return await execute_server_tool(
+                        call,
+                        tool_resource_fetcher=self._tool_resource_resolver,
+                    )
+
+                self._server_tool_executor = _executor
+            else:
+                self._server_tool_executor = execute_server_tool
+
         logger.info(f"DuplexPipeline initialized for session {session_id}")

+    def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
+        """Use session-scoped monotonic sequence provider for envelope events."""
+        self._next_seq = provider
+
     def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
         """
         Apply runtime overrides from WS session.start metadata.
@@ -276,6 +335,136 @@ class DuplexPipeline:
         if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
             self.llm_service.set_tool_schemas(self._resolved_tool_schemas())

+    def resolved_runtime_config(self) -> Dict[str, Any]:
+        """Return current effective runtime configuration without secrets."""
+        llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
+        llm_base_url = (
+            self._runtime_llm.get("baseUrl")
+            or settings.llm_api_url
+            or self._default_llm_base_url(llm_provider)
+        )
+        tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
+        asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
+        output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
+        if not output_mode:
+            output_mode = "audio" if self._tts_output_enabled() else "text"
+
+        return {
+            "output": {"mode": output_mode},
+            "services": {
+                "llm": {
+                    "provider": llm_provider,
+                    "model": str(self._runtime_llm.get("model") or settings.llm_model),
+                    "baseUrl": llm_base_url,
+                },
+                "asr": {
+                    "provider": asr_provider,
+                    "model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
+                    "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
+                    "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
+                },
+                "tts": {
+                    "enabled": self._tts_output_enabled(),
+                    "provider": tts_provider,
+                    "model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
+                    "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
+                    "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
+                },
+            },
+            "tools": {
+                "allowlist": self._resolved_tool_allowlist(),
+            },
+            "tracks": {
+                "audio_in": self.track_audio_in,
+                "audio_out": self.track_audio_out,
+                "control": self.track_control,
+            },
+        }
+
+    def _next_event_seq(self) -> int:
+        if self._next_seq:
+            return self._next_seq()
+        self._local_seq += 1
+        return self._local_seq
+
+    def _event_source(self, event_type: str) -> str:
+        if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
+            return "asr"
+        if event_type.startswith("assistant.response."):
+            return "llm"
+        if event_type.startswith("assistant.tool_"):
+            return "tool"
+        if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
+            return "tts"
+        return "system"
+
+    def _new_id(self, prefix: str, counter: int) -> str:
+        return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
+
+    def _start_turn(self) -> str:
+        self._turn_count += 1
+        self._current_turn_id = self._new_id("turn", self._turn_count)
+        self._current_utterance_id = None
+        self._current_response_id = None
+        self._current_tts_id = None
+        return self._current_turn_id
+
+    def _start_response(self) -> str:
+        self._response_count += 1
+        self._current_response_id = self._new_id("resp", self._response_count)
+        self._current_tts_id = None
+        return self._current_response_id
+
+    def _start_tts(self) -> str:
+        self._tts_count += 1
+        self._current_tts_id = self._new_id("tts", self._tts_count)
+        return self._current_tts_id
+
+    def _finalize_utterance(self) -> str:
+        if self._current_utterance_id:
+            return self._current_utterance_id
+        self._utterance_count += 1
+        self._current_utterance_id = self._new_id("utt", self._utterance_count)
+        if not self._current_turn_id:
+            self._start_turn()
+        return self._current_utterance_id
+
+    def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
+        event_type = str(event.get("type") or "")
+        source = str(event.get("source") or self._event_source(event_type))
+        track_id = event.get("trackId")
+        if not track_id:
+            if source == "asr":
+                track_id = self.track_audio_in
+            elif source in {"llm", "tts", "tool"}:
+                track_id = self.track_audio_out
+            else:
+                track_id = self.track_control
+
+        data = event.get("data")
+        if not isinstance(data, dict):
+            data = {}
+        if self._current_turn_id:
+            data.setdefault("turn_id", self._current_turn_id)
+        if self._current_utterance_id:
+            data.setdefault("utterance_id", self._current_utterance_id)
+        if self._current_response_id:
+            data.setdefault("response_id", self._current_response_id)
+        if self._current_tts_id:
+            data.setdefault("tts_id", self._current_tts_id)
+
+        for k, v in event.items():
+            if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
+                continue
+            data.setdefault(k, v)
+
+        event["sessionId"] = self.session_id
+        event["seq"] = self._next_event_seq()
+        event["source"] = source
+        event["trackId"] = track_id
+        event["data"] = data
+        return event
+
     @staticmethod
     def _coerce_bool(value: Any) -> Optional[bool]:
         if isinstance(value, bool):
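For orientation, a sketch of how _envelope_event fills in a bare event; the session ID and generated IDs below are made up:

    # Assuming a pipeline with session_id="s1", one turn already started,
    # and no session-scoped seq provider installed:
    event = {"type": "transcript.delta", "text": "hello"}
    enveloped = pipeline._envelope_event(event)
    # enveloped == {
    #     "type": "transcript.delta",
    #     "text": "hello",               # extra keys stay at top level...
    #     "sessionId": "s1",
    #     "seq": 1,                      # local counter fallback
    #     "source": "asr",               # inferred from the "transcript." prefix
    #     "trackId": "audio_in",         # default track for asr-sourced events
    #     "data": {"turn_id": "turn_1_...", "text": "hello"},  # ...and are mirrored into data
    # }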
@@ -295,6 +484,18 @@ class DuplexPipeline:
         normalized = str(provider or "").strip().lower()
         return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}

+    @staticmethod
+    def _is_llm_provider_supported(provider: Any) -> bool:
+        normalized = str(provider or "").strip().lower()
+        return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"}
+
+    @staticmethod
+    def _default_llm_base_url(provider: Any) -> Optional[str]:
+        normalized = str(provider or "").strip().lower()
+        if normalized == "siliconflow":
+            return "https://api.siliconflow.cn/v1"
+        return None
+
     def _tts_output_enabled(self) -> bool:
         enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
         if enabled is not None:
@@ -370,20 +571,25 @@ class DuplexPipeline:
         try:
             # Connect LLM service
             if not self.llm_service:
-                llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
-                llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
+                llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
+                llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
+                llm_base_url = (
+                    self._runtime_llm.get("baseUrl")
+                    or settings.llm_api_url
+                    or self._default_llm_base_url(llm_provider)
+                )
                 llm_model = self._runtime_llm.get("model") or settings.llm_model
-                llm_provider = (self._runtime_llm.get("provider") or "openai").lower()

-                if llm_provider == "openai" and llm_api_key:
+                if self._is_llm_provider_supported(llm_provider) and llm_api_key:
                     self.llm_service = OpenAILLMService(
                         api_key=llm_api_key,
                         base_url=llm_base_url,
                         model=llm_model,
                         knowledge_config=self._resolved_knowledge_config(),
+                        knowledge_searcher=self._knowledge_searcher,
                     )
                 else:
-                    logger.warning("No OpenAI API key - using mock LLM")
+                    logger.warning("LLM provider unsupported or API key missing - using mock LLM")
                     self.llm_service = MockLLMService()

                 if hasattr(self.llm_service, "set_knowledge_config"):
@@ -399,20 +605,22 @@ class DuplexPipeline:
             if tts_output_enabled:
                 if not self.tts_service:
                     tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
-                    tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
+                    tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
                     tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
                     tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
-                    tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
+                    tts_model = self._runtime_tts.get("model") or settings.tts_model
                     tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)

                     if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
                         self.tts_service = OpenAICompatibleTTSService(
                             api_key=tts_api_key,
                             api_url=tts_api_url,
                             voice=tts_voice,
-                            model=tts_model,
+                            model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
                             sample_rate=settings.sample_rate,
                             speed=tts_speed
                         )
-                        logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
+                        logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
                     else:
                         self.tts_service = EdgeTTSService(
                             voice=tts_voice,
@@ -435,21 +643,23 @@ class DuplexPipeline:
             # Connect ASR service
             if not self.asr_service:
                 asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
-                asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
-                asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
+                asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
+                asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
+                asr_model = self._runtime_asr.get("model") or settings.asr_model
                 asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
                 asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)

                 if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
                     self.asr_service = OpenAICompatibleASRService(
                         api_key=asr_api_key,
-                        model=asr_model,
+                        api_url=asr_api_url,
+                        model=asr_model or "FunAudioLLM/SenseVoiceSmall",
                         sample_rate=settings.sample_rate,
                         interim_interval_ms=asr_interim_interval,
                         min_audio_for_interim_ms=asr_min_audio_ms,
                         on_transcript=self._on_transcript_callback
                     )
-                    logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
+                    logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
                 else:
                     self.asr_service = BufferedASRService(
                         sample_rate=settings.sample_rate
@@ -472,11 +682,13 @@ class DuplexPipeline:
                 greeting_to_speak = generated_greeting
                 self.conversation.greeting = generated_greeting
             if greeting_to_speak:
+                self._start_turn()
+                self._start_response()
                 await self._send_event(
                     ev(
                         "assistant.response.final",
                         text=greeting_to_speak,
-                        trackId=self.session_id,
+                        trackId=self.track_audio_out,
                     ),
                     priority=20,
                 )
@@ -494,10 +706,58 @@ class DuplexPipeline:
         await self._outbound_q.put((priority, self._outbound_seq, kind, payload))

     async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
-        await self._enqueue_outbound("event", event, priority)
+        await self._enqueue_outbound("event", self._envelope_event(event), priority)

     async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
-        await self._enqueue_outbound("audio", pcm_bytes, priority)
+        if not pcm_bytes:
+            return
+        self._audio_out_frame_buffer += pcm_bytes
+        while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
+            frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
+            self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
+            await self._enqueue_outbound("audio", frame, priority)
+
+    async def _flush_audio_out_frames(self, priority: int = 50) -> None:
+        """Flush remaining outbound audio as one padded 20ms PCM frame."""
+        if not self._audio_out_frame_buffer:
+            return
+        tail = self._audio_out_frame_buffer
+        self._audio_out_frame_buffer = b""
+        if len(tail) < self._PCM_FRAME_BYTES:
+            tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
+        await self._enqueue_outbound("audio", tail, priority)
+
+    async def _emit_transcript_delta(self, text: str) -> None:
+        await self._send_event(
+            {
+                **ev(
+                    "transcript.delta",
+                    trackId=self.track_audio_in,
+                    text=text,
+                )
+            },
+            priority=30,
+        )
+
+    async def _emit_llm_delta(self, text: str) -> None:
+        await self._send_event(
+            {
+                **ev(
+                    "assistant.response.delta",
+                    trackId=self.track_audio_out,
+                    text=text,
+                )
+            },
+            priority=20,
+        )
+
+    async def _flush_pending_llm_delta(self) -> None:
+        if not self._pending_llm_delta:
+            return
+        chunk = self._pending_llm_delta
+        self._pending_llm_delta = ""
+        self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
+        await self._emit_llm_delta(chunk)

     async def _outbound_loop(self) -> None:
         """Single sender loop that enforces priority for interrupt events."""
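A standalone sketch of the 20 ms re-framing that _send_audio and _flush_audio_out_frames implement (re-written here without the pipeline for illustration):

    _PCM_FRAME_BYTES = 640

    def reframe(buffer: bytes, chunk: bytes) -> tuple[list[bytes], bytes]:
        # Emit as many full 20 ms frames as possible; keep the tail buffered.
        buffer += chunk
        frames = []
        while len(buffer) >= _PCM_FRAME_BYTES:
            frames.append(buffer[:_PCM_FRAME_BYTES])
            buffer = buffer[_PCM_FRAME_BYTES:]
        return frames, buffer

    frames, tail = reframe(b"", b"\x00" * 1000)
    assert len(frames) == 1 and len(tail) == 360
    # On flush, the tail is zero-padded to one full frame:
    padded = tail + b"\x00" * (_PCM_FRAME_BYTES - len(tail))
    assert len(padded) == 640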
@@ -546,13 +806,13 @@ class DuplexPipeline:

         # Emit VAD event
         await self.event_bus.publish(event_type, {
-            "trackId": self.session_id,
+            "trackId": self.track_audio_in,
             "probability": probability
         })
         await self._send_event(
             ev(
                 "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
-                trackId=self.session_id,
+                trackId=self.track_audio_in,
                 probability=probability,
             ),
             priority=30,
@@ -661,6 +921,9 @@ class DuplexPipeline:
         # Cancel any current speaking
         await self._stop_current_speech()

+        self._start_turn()
+        self._finalize_utterance()
+
         # Start new turn
         await self.conversation.end_user_turn(text)
         self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +946,45 @@ class DuplexPipeline:
         if text == self._last_sent_transcript and not is_final:
             return

+        now_ms = time.monotonic() * 1000.0
         self._last_sent_transcript = text

         # Send transcript event to client
-        await self._send_event({
-            **ev(
-                "transcript.final" if is_final else "transcript.delta",
-                trackId=self.session_id,
-                text=text,
-            )
-        }, priority=30)
+        if is_final:
+            self._pending_transcript_delta = ""
+            self._last_transcript_delta_emit_ms = 0.0
+            await self._send_event(
+                {
+                    **ev(
+                        "transcript.final",
+                        trackId=self.track_audio_in,
+                        text=text,
+                    )
+                },
+                priority=30,
+            )
+            logger.debug(f"Sent transcript (final): {text[:50]}...")
+            return
+
+        self._pending_transcript_delta = text
+        should_emit = (
+            self._last_transcript_delta_emit_ms <= 0.0
+            or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
+        )
+        if should_emit and self._pending_transcript_delta:
+            delta = self._pending_transcript_delta
+            self._pending_transcript_delta = ""
+            self._last_transcript_delta_emit_ms = now_ms
+            await self._emit_transcript_delta(delta)

         if not is_final:
             logger.info(f"[ASR] ASR interim: {text[:100]}")
-        logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
+        logger.debug(f"Sent transcript (interim): {text[:50]}...")

     async def _on_speech_start(self) -> None:
         """Handle user starting to speak."""
         if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
+            self._start_turn()
+            self._finalize_utterance()
             await self.conversation.start_user_turn()
             self._audio_buffer = b""
             self._last_sent_transcript = ""
@@ -779,6 +1063,7 @@ class DuplexPipeline:
             return

         logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
+        self._finalize_utterance()

         # For ASR backends that already emitted final via callback,
         # avoid duplicating transcript.final on EOU.
@@ -786,7 +1071,7 @@ class DuplexPipeline:
             await self._send_event({
                 **ev(
                     "transcript.final",
-                    trackId=self.session_id,
+                    trackId=self.track_audio_in,
                     text=user_text,
                 )
             }, priority=25)
@@ -794,6 +1079,8 @@ class DuplexPipeline:
         # Clear buffers
         self._audio_buffer = b""
         self._last_sent_transcript = ""
+        self._pending_transcript_delta = ""
+        self._last_transcript_delta_emit_ms = 0.0
         self._asr_capture_active = False
         self._pending_speech_audio = b""
@@ -881,6 +1168,23 @@ class DuplexPipeline:
             result[name] = executor
         return result

+    def _resolved_tool_allowlist(self) -> List[str]:
+        names: set[str] = set()
+        for item in self._runtime_tools:
+            if isinstance(item, str):
+                name = item.strip()
+                if name:
+                    names.add(name)
+                continue
+            if not isinstance(item, dict):
+                continue
+            fn = item.get("function")
+            if isinstance(fn, dict) and fn.get("name"):
+                names.add(str(fn.get("name")).strip())
+            elif item.get("name"):
+                names.add(str(item.get("name")).strip())
+        return sorted([name for name in names if name])
+
     def _tool_name(self, tool_call: Dict[str, Any]) -> str:
         fn = tool_call.get("function")
         if isinstance(fn, dict):
@@ -894,6 +1198,44 @@ class DuplexPipeline:
             # Default to server execution unless explicitly marked as client.
             return "server"

+    def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
+        fn = tool_call.get("function")
+        if not isinstance(fn, dict):
+            return {}
+        raw = fn.get("arguments")
+        if isinstance(raw, dict):
+            return raw
+        if isinstance(raw, str) and raw.strip():
+            try:
+                parsed = json.loads(raw)
+                return parsed if isinstance(parsed, dict) else {"raw": raw}
+            except Exception:
+                return {"raw": raw}
+        return {}
+
+    def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
+        status = result.get("status") if isinstance(result.get("status"), dict) else {}
+        status_code = int(status.get("code") or 0) if status else 0
+        status_message = str(status.get("message") or "") if status else ""
+        tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
+        tool_name = str(result.get("name") or "unknown_tool")
+        ok = bool(200 <= status_code < 300)
+        retryable = status_code >= 500 or status_code in {429, 408}
+        error: Optional[Dict[str, Any]] = None
+        if not ok:
+            error = {
+                "code": status_code or 500,
+                "message": status_message or "tool_execution_failed",
+                "retryable": retryable,
+            }
+        return {
+            "tool_call_id": tool_call_id,
+            "tool_name": tool_name,
+            "ok": ok,
+            "error": error,
+            "status": {"code": status_code, "message": status_message},
+        }
+
     async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
         tool_name = str(result.get("name") or "unknown_tool")
         call_id = str(result.get("tool_call_id") or result.get("id") or "")
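A worked example of the _normalize_tool_result contract (input values are illustrative):

    raw = {
        "tool_call_id": "call_123",   # hypothetical ID
        "name": "current_time",
        "status": {"code": 504, "message": "upstream timeout"},
    }
    normalized = pipeline._normalize_tool_result(raw)
    # normalized == {
    #     "tool_call_id": "call_123",
    #     "tool_name": "current_time",
    #     "ok": False,                                   # only 2xx counts as ok
    #     "error": {"code": 504, "message": "upstream timeout",
    #               "retryable": True},                  # 5xx, 429, 408 are retryable
    #     "status": {"code": 504, "message": "upstream timeout"},
    # }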
@@ -904,12 +1246,17 @@ class DuplexPipeline:
             f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
             f"status={status_code} {status_message}".strip()
         )
+        normalized = self._normalize_tool_result(result)
         await self._send_event(
             {
                 **ev(
                     "assistant.tool_result",
-                    trackId=self.session_id,
+                    trackId=self.track_audio_out,
                     source=source,
+                    tool_call_id=normalized["tool_call_id"],
+                    tool_name=normalized["tool_name"],
+                    ok=normalized["ok"],
+                    error=normalized["error"],
                     result=result,
                 )
             },
@@ -927,6 +1274,9 @@ class DuplexPipeline:
             call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
             if not call_id:
                 continue
+            if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
+                logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
+                continue
             if call_id in self._completed_tool_call_ids:
                 logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
                 continue
@@ -972,6 +1322,7 @@ class DuplexPipeline:
                 }
             finally:
                 self._pending_tool_waiters.pop(call_id, None)
+                self._pending_client_tool_call_ids.discard(call_id)

     def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
         if isinstance(item, LLMStreamEvent):
@@ -998,6 +1349,11 @@ class DuplexPipeline:
             user_text: User's transcribed text
         """
         try:
+            if not self._current_turn_id:
+                self._start_turn()
+            if not self._current_utterance_id:
+                self._finalize_utterance()
+            self._start_response()
             # Start latency tracking
             self._turn_start_time = time.time()
             self._first_audio_sent = False
@@ -1012,6 +1368,8 @@ class DuplexPipeline:
             self._drop_outbound_audio = False

             first_audio_sent = False
+            self._pending_llm_delta = ""
+            self._last_llm_delta_emit_ms = 0.0
             for _ in range(max_rounds):
                 if self._interrupt_event.is_set():
                     break
@@ -1028,6 +1386,7 @@ class DuplexPipeline:

                     event = self._normalize_stream_event(raw_event)
                     if event.type == "tool_call":
+                        await self._flush_pending_llm_delta()
                         tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
                         if not tool_call:
                             continue
@@ -1045,11 +1404,19 @@ class DuplexPipeline:
                            f"executor={executor} args={args_preview}"
                        )
                        tool_calls.append(enriched_tool_call)
+                        tool_arguments = self._tool_arguments(enriched_tool_call)
+                        if executor == "client" and call_id:
+                            self._pending_client_tool_call_ids.add(call_id)
                        await self._send_event(
                            {
                                **ev(
                                    "assistant.tool_call",
-                                    trackId=self.session_id,
+                                    trackId=self.track_audio_out,
                                    tool_call_id=call_id,
                                    tool_name=tool_name,
+                                    arguments=tool_arguments,
+                                    executor=executor,
+                                    timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
                                    tool_call=enriched_tool_call,
                                )
                            },
@@ -1071,19 +1438,13 @@ class DuplexPipeline:
                        round_response += text_chunk
                        sentence_buffer += text_chunk
                        await self.conversation.update_assistant_text(text_chunk)

-                        await self._send_event(
-                            {
-                                **ev(
-                                    "assistant.response.delta",
-                                    trackId=self.session_id,
-                                    text=text_chunk,
-                                )
-                            },
-                            # Keep delta/final on the same event priority so FIFO seq
-                            # preserves stream order (avoid late-delta after final).
-                            priority=20,
-                        )
+                        self._pending_llm_delta += text_chunk
+                        now_ms = time.monotonic() * 1000.0
+                        if (
+                            self._last_llm_delta_emit_ms <= 0.0
+                            or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
+                        ):
+                            await self._flush_pending_llm_delta()

                        while True:
                            split_result = extract_tts_sentence(
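The throttle admits at most one assistant.response.delta per 80 ms window; the gate reduces to this standalone sketch:

    _LLM_DELTA_THROTTLE_MS = 80

    def should_flush(last_emit_ms: float, now_ms: float) -> bool:
        # First delta always flushes; later ones wait out the window.
        return last_emit_ms <= 0.0 or now_ms - last_emit_ms >= _LLM_DELTA_THROTTLE_MS

    assert should_flush(0.0, 10.0)       # first chunk
    assert not should_flush(10.0, 50.0)  # only 40 ms since last emit
    assert should_flush(10.0, 95.0)      # 85 ms since last emit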
@@ -1112,11 +1473,12 @@ class DuplexPipeline:

                    if self._tts_output_enabled() and not self._interrupt_event.is_set():
                        if not first_audio_sent:
+                            self._start_tts()
                            await self._send_event(
                                {
                                    **ev(
                                        "output.audio.start",
-                                        trackId=self.session_id,
+                                        trackId=self.track_audio_out,
                                    )
                                },
                                priority=10,
@@ -1130,6 +1492,7 @@ class DuplexPipeline:
                )

                remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
+                await self._flush_pending_llm_delta()
                if (
                    self._tts_output_enabled()
                    and remaining_text
@@ -1137,11 +1500,12 @@ class DuplexPipeline:
                    and not self._interrupt_event.is_set()
                ):
                    if not first_audio_sent:
+                        self._start_tts()
                        await self._send_event(
                            {
                                **ev(
                                    "output.audio.start",
-                                    trackId=self.session_id,
+                                    trackId=self.track_audio_out,
                                )
                            },
                            priority=10,
@@ -1172,7 +1536,7 @@ class DuplexPipeline:

                try:
                    result = await asyncio.wait_for(
-                        execute_server_tool(call),
+                        self._server_tool_executor(call),
                        timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
@@ -1204,11 +1568,12 @@ class DuplexPipeline:
            ]

            if full_response and not self._interrupt_event.is_set():
+                await self._flush_pending_llm_delta()
                await self._send_event(
                    {
                        **ev(
                            "assistant.response.final",
-                            trackId=self.session_id,
+                            trackId=self.track_audio_out,
                            text=full_response,
                        )
                    },
@@ -1217,10 +1582,11 @@ class DuplexPipeline:

            # Send track end
            if first_audio_sent:
+                await self._flush_audio_out_frames(priority=50)
                await self._send_event({
                    **ev(
                        "output.audio.end",
-                        trackId=self.session_id,
+                        trackId=self.track_audio_out,
                    )
                }, priority=10)
@@ -1241,6 +1607,8 @@ class DuplexPipeline:
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0
+        self._current_response_id = None
+        self._current_tts_id = None

    async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
        """
@@ -1277,7 +1645,7 @@ class DuplexPipeline:
            await self._send_event({
                **ev(
                    "metrics.ttfb",
-                    trackId=self.session_id,
+                    trackId=self.track_audio_out,
                    latencyMs=round(ttfb_ms),
                )
            }, priority=25)
@@ -1354,10 +1722,11 @@ class DuplexPipeline:
        first_audio_sent = False

        # Send track start event
+        self._start_tts()
        await self._send_event({
            **ev(
                "output.audio.start",
-                trackId=self.session_id,
+                trackId=self.track_audio_out,
            )
        }, priority=10)
@@ -1379,7 +1748,7 @@ class DuplexPipeline:
        await self._send_event({
            **ev(
                "metrics.ttfb",
-                trackId=self.session_id,
+                trackId=self.track_audio_out,
                latencyMs=round(ttfb_ms),
            )
        }, priority=25)
@@ -1391,10 +1760,11 @@ class DuplexPipeline:
            await asyncio.sleep(0.01)

        # Send track end event
+        await self._flush_audio_out_frames(priority=50)
        await self._send_event({
            **ev(
                "output.audio.end",
-                trackId=self.session_id,
+                trackId=self.track_audio_out,
            )
        }, priority=10)
@@ -1422,13 +1792,14 @@ class DuplexPipeline:
        self._interrupt_event.set()
        self._is_bot_speaking = False
        self._drop_outbound_audio = True
+        self._audio_out_frame_buffer = b""

        # Send interrupt event to client IMMEDIATELY
        # This must happen BEFORE canceling services, so client knows to discard in-flight audio
        await self._send_event({
            **ev(
                "response.interrupted",
-                trackId=self.session_id,
+                trackId=self.track_audio_out,
            )
        }, priority=0)
@@ -1455,6 +1826,7 @@ class DuplexPipeline:
    async def _stop_current_speech(self) -> None:
        """Stop any current speech task."""
        self._drop_outbound_audio = True
+        self._audio_out_frame_buffer = b""
        if self._current_turn_task and not self._current_turn_task.done():
            self._interrupt_event.set()
            self._current_turn_task.cancel()
engine/core/history_bridge.py (new file, 244 lines)
@@ -0,0 +1,244 @@
"""Async history bridge for non-blocking transcript persistence."""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass
from typing import Any, Optional

from loguru import logger


@dataclass
class _HistoryTranscriptJob:
    call_id: str
    turn_index: int
    speaker: str
    content: str
    start_ms: int
    end_ms: int
    duration_ms: int


class SessionHistoryBridge:
    """Session-scoped buffered history writer with background retries."""

    _STOP_SENTINEL = object()

    def __init__(
        self,
        *,
        history_writer: Any,
        enabled: bool,
        queue_max_size: int,
        retry_max_attempts: int,
        retry_backoff_sec: float,
        finalize_drain_timeout_sec: float,
    ):
        self._history_writer = history_writer
        self._enabled = bool(enabled and history_writer is not None)
        self._queue_max_size = max(1, int(queue_max_size))
        self._retry_max_attempts = max(0, int(retry_max_attempts))
        self._retry_backoff_sec = max(0.0, float(retry_backoff_sec))
        self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec))

        self._call_id: Optional[str] = None
        self._turn_index: int = 0
        self._started_mono: Optional[float] = None
        self._finalized: bool = False
        self._worker_task: Optional[asyncio.Task] = None
        self._finalize_lock = asyncio.Lock()
        self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size)

    @property
    def enabled(self) -> bool:
        return self._enabled

    @property
    def call_id(self) -> Optional[str]:
        return self._call_id

    async def start_call(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str,
    ) -> Optional[str]:
        """Create remote call record and start background worker."""
        if not self._enabled or self._call_id:
            return self._call_id

        call_id = await self._history_writer.create_call_record(
            user_id=user_id,
            assistant_id=assistant_id,
            source=source,
        )
        if not call_id:
            return None

        self._call_id = str(call_id)
        self._turn_index = 0
        self._finalized = False
        self._started_mono = time.monotonic()
        self._ensure_worker()
        return self._call_id

    def elapsed_ms(self) -> int:
        if self._started_mono is None:
            return 0
        return max(0, int((time.monotonic() - self._started_mono) * 1000))

    def enqueue_turn(self, *, role: str, text: str) -> bool:
        """Queue one transcript write without blocking the caller."""
        if not self._enabled or not self._call_id or self._finalized:
            return False

        content = str(text or "").strip()
        if not content:
            return False

        speaker = "human" if str(role or "").strip().lower() == "user" else "ai"
        end_ms = self.elapsed_ms()
        estimated_duration_ms = max(300, min(12000, len(content) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        job = _HistoryTranscriptJob(
            call_id=self._call_id,
            turn_index=self._turn_index,
            speaker=speaker,
            content=content,
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        self._turn_index += 1
        self._ensure_worker()

        try:
            self._queue.put_nowait(job)
            return True
        except asyncio.QueueFull:
            logger.warning(
                "History queue full; dropping transcript call_id={} turn={}",
                self._call_id,
                job.turn_index,
            )
            return False

    async def finalize(self, *, status: str) -> bool:
        """Finalize history record once; waits briefly for queue drain."""
        if not self._enabled or not self._call_id:
            return False

        async with self._finalize_lock:
            if self._finalized:
                return True

            await self._drain_queue()
            ok = await self._history_writer.finalize_call_record(
                call_id=self._call_id,
                status=status,
                duration_seconds=self.duration_seconds(),
            )
            if ok:
                self._finalized = True
            await self._stop_worker()
            return ok

    async def shutdown(self) -> None:
        """Stop worker task and release queue resources."""
        await self._stop_worker()

    def duration_seconds(self) -> int:
        if self._started_mono is None:
            return 0
        return max(0, int(time.monotonic() - self._started_mono))

    def _ensure_worker(self) -> None:
        if self._worker_task and not self._worker_task.done():
            return
        self._worker_task = asyncio.create_task(self._worker_loop())

    async def _drain_queue(self) -> None:
        if self._finalize_drain_timeout_sec <= 0:
            return
        try:
            await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec)
        except asyncio.TimeoutError:
            logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec)

    async def _stop_worker(self) -> None:
        task = self._worker_task
        if not task:
            return
        if task.done():
            self._worker_task = None
            return

        sent = False
        try:
            self._queue.put_nowait(self._STOP_SENTINEL)
            sent = True
        except asyncio.QueueFull:
            pass

        if not sent:
            try:
                await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5)
            except asyncio.TimeoutError:
                task.cancel()

        try:
            await asyncio.wait_for(task, timeout=1.5)
        except asyncio.TimeoutError:
            task.cancel()
            try:
                await task
            except Exception:
                pass
        except asyncio.CancelledError:
            pass
        finally:
            self._worker_task = None

    async def _worker_loop(self) -> None:
        while True:
            item = await self._queue.get()
            try:
                if item is self._STOP_SENTINEL:
                    return

                assert isinstance(item, _HistoryTranscriptJob)
                await self._write_with_retry(item)
            except Exception as exc:
                logger.warning("History worker write failed unexpectedly: {}", exc)
            finally:
                self._queue.task_done()

    async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool:
        for attempt in range(self._retry_max_attempts + 1):
            ok = await self._history_writer.add_transcript(
                call_id=job.call_id,
                turn_index=job.turn_index,
                speaker=job.speaker,
                content=job.content,
                start_ms=job.start_ms,
                end_ms=job.end_ms,
                duration_ms=job.duration_ms,
            )
            if ok:
                return True

            if attempt >= self._retry_max_attempts:
                logger.warning(
                    "History write dropped after retries call_id={} turn={}",
                    job.call_id,
                    job.turn_index,
                )
                return False

            if self._retry_backoff_sec > 0:
                await asyncio.sleep(self._retry_backoff_sec * (2**attempt))
        return False
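_write_with_retry sleeps retry_backoff_sec * 2**attempt between attempts, so the backoff doubles each retry; a quick check with illustrative settings values (not the commit's defaults):

    retry_backoff_sec = 0.5
    retry_max_attempts = 3
    schedule = [retry_backoff_sec * (2 ** attempt) for attempt in range(retry_max_attempts)]
    assert schedule == [0.5, 1.0, 2.0]  # sleeps before retries 1, 2, 3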
engine/core/ports/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Port interfaces for engine-side integration boundaries."""

from core.ports.backend import (
    AssistantConfigProvider,
    BackendGateway,
    HistoryWriter,
    KnowledgeSearcher,
    ToolResourceResolver,
)

__all__ = [
    "AssistantConfigProvider",
    "BackendGateway",
    "HistoryWriter",
    "KnowledgeSearcher",
    "ToolResourceResolver",
]
engine/core/ports/backend.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""Backend integration ports.

These interfaces define the boundary between engine runtime logic and
backend-side capabilities (config lookup, history persistence, retrieval,
and tool resource discovery).
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Protocol


class AssistantConfigProvider(Protocol):
    """Port for loading trusted assistant runtime configuration."""

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Fetch assistant configuration payload."""


class HistoryWriter(Protocol):
    """Port for persisting call and transcript history."""

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Create a call record and return backend call ID."""

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Append one transcript turn segment."""

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Finalize a call record."""


class KnowledgeSearcher(Protocol):
    """Port for RAG / knowledge retrieval operations."""

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search a knowledge source and return ranked snippets."""


class ToolResourceResolver(Protocol):
    """Port for resolving tool metadata/configuration."""

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Fetch tool resource configuration."""


class BackendGateway(
    AssistantConfigProvider,
    HistoryWriter,
    KnowledgeSearcher,
    ToolResourceResolver,
    Protocol,
):
    """Composite backend gateway interface used by engine services."""
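Because these ports are typing.Protocols, any structurally matching class satisfies them; a minimal in-memory HistoryWriter for tests might look like this (entirely hypothetical, not part of the commit):

    from typing import Any, Dict, Optional

    class InMemoryHistoryWriter:
        # Structural match for core.ports.backend.HistoryWriter.

        def __init__(self) -> None:
            self.calls: Dict[str, Dict[str, Any]] = {}

        async def create_call_record(
            self, *, user_id: int, assistant_id: Optional[str], source: str = "debug"
        ) -> Optional[str]:
            call_id = f"call_{len(self.calls) + 1}"
            self.calls[call_id] = {
                "user_id": user_id,
                "assistant_id": assistant_id,
                "source": source,
                "transcripts": [],
            }
            return call_id

        async def add_transcript(
            self, *, call_id: str, turn_index: int, speaker: str, content: str,
            start_ms: int, end_ms: int,
            confidence: Optional[float] = None, duration_ms: Optional[int] = None,
        ) -> bool:
            record = self.calls.get(call_id)
            if record is None:
                return False
            record["transcripts"].append(
                {"turn_index": turn_index, "speaker": speaker, "content": content}
            )
            return True

        async def finalize_call_record(
            self, *, call_id: str, status: str, duration_seconds: int
        ) -> bool:
            record = self.calls.get(call_id)
            if record is None:
                return False
            record.update(status=status, duration_seconds=duration_seconds)
            return True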
@@ -1,22 +1,19 @@
 """Session management for active calls."""

 import asyncio
-import uuid
 import hashlib
 import json
-import time
 import re
+import time
 from enum import Enum
 from typing import Optional, Dict, Any, List
 from loguru import logger

-from app.backend_client import (
-    create_history_call_record,
-    add_history_transcript,
-    finalize_history_call_record,
-)
+from app.backend_adapters import build_backend_adapter_from_settings
 from core.transports import BaseTransport
 from core.duplex_pipeline import DuplexPipeline
 from core.conversation import ConversationTurn
+from core.history_bridge import SessionHistoryBridge
 from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
 from app.config import settings
 from services.base import LLMMessage
@@ -49,7 +46,39 @@ class Session:
     Uses full duplex voice conversation pipeline.
     """

-    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
+    TRACK_AUDIO_IN = "audio_in"
+    TRACK_AUDIO_OUT = "audio_out"
+    TRACK_CONTROL = "control"
+    AUDIO_FRAME_BYTES = 640  # 16k mono pcm_s16le, 20ms
+    _CLIENT_METADATA_OVERRIDES = {
+        "firstTurnMode",
+        "greeting",
+        "generatedOpenerEnabled",
+        "systemPrompt",
+        "output",
+        "bargeIn",
+        "knowledge",
+        "knowledgeBaseId",
+        "history",
+        "userId",
+        "assistantId",
+        "source",
+    }
+    _CLIENT_METADATA_ID_KEYS = {
+        "appId",
+        "app_id",
+        "channel",
+        "configVersionId",
+        "config_version_id",
+    }
+
+    def __init__(
+        self,
+        session_id: str,
+        transport: BaseTransport,
+        use_duplex: bool = None,
+        backend_gateway: Optional[Any] = None,
+    ):
         """
         Initialize session.

@@ -61,12 +90,23 @@ class Session:
         self.id = session_id
         self.transport = transport
         self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
+        self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
+        self._history_bridge = SessionHistoryBridge(
+            history_writer=self._backend_gateway,
+            enabled=settings.history_enabled,
+            queue_max_size=settings.history_queue_max_size,
+            retry_max_attempts=settings.history_retry_max_attempts,
+            retry_backoff_sec=settings.history_retry_backoff_sec,
+            finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
+        )

         self.pipeline = DuplexPipeline(
             transport=transport,
             session_id=session_id,
             system_prompt=settings.duplex_system_prompt,
-            greeting=settings.duplex_greeting
+            greeting=settings.duplex_greeting,
+            knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
+            tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
         )

         # Session state
@@ -78,17 +118,15 @@ class Session:
         self.authenticated: bool = False

         # Track IDs
-        self.current_track_id: Optional[str] = str(uuid.uuid4())
-        self._history_call_id: Optional[str] = None
-        self._history_turn_index: int = 0
-        self._history_call_started_mono: Optional[float] = None
-        self._history_finalized: bool = False
+        self.current_track_id: str = self.TRACK_CONTROL
+        self._event_seq: int = 0
         self._cleanup_lock = asyncio.Lock()
         self._cleaned_up = False
         self.workflow_runner: Optional[WorkflowRunner] = None
         self._workflow_last_user_text: str = ""
         self._workflow_initial_node: Optional[WorkflowNodeDef] = None

+        self.pipeline.set_event_sequence_provider(self._next_event_seq)
         self.pipeline.conversation.on_turn_complete(self._on_turn_complete)

         logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +167,47 @@ class Session:
                 "client",
                 "Audio received before session.start",
                 "protocol.order",
+                stage="protocol",
+                retryable=False,
             )
             return

         try:
-            await self.pipeline.process_audio(audio_bytes)
+            if not audio_bytes:
+                return
+            if len(audio_bytes) % 2 != 0:
+                await self._send_error(
+                    "client",
+                    "Invalid PCM payload: odd number of bytes",
+                    "audio.invalid_pcm",
+                    stage="audio",
+                    retryable=False,
+                )
+                return
+
+            frame_bytes = self.AUDIO_FRAME_BYTES
+            if len(audio_bytes) % frame_bytes != 0:
+                await self._send_error(
+                    "client",
+                    f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
+                    "audio.frame_size_mismatch",
+                    stage="audio",
+                    retryable=False,
+                )
+                return
+
+            for i in range(0, len(audio_bytes), frame_bytes):
+                frame = audio_bytes[i : i + frame_bytes]
+                await self.pipeline.process_audio(frame)
         except Exception as e:
             logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
             await self._send_error(
                 "server",
                 f"Audio processing failed: {e}",
                 "audio.processing_failed",
+                stage="audio",
+                retryable=True,
             )

     async def _handle_v1_message(self, message: Any) -> None:
         """Route validated WS v1 message to handlers."""
@@ -176,7 +248,7 @@ class Session:
             else:
                 await self.pipeline.interrupt()
         elif isinstance(message, ToolCallResultsMessage):
-            await self.pipeline.handle_tool_call_results(message.results)
+            await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
         elif isinstance(message, SessionStopMessage):
             await self._handle_session_stop(message.reason)
         else:
@@ -198,9 +270,9 @@ class Session:
             self.ws_state = WsSessionState.STOPPED
             return

-        auth_payload = message.auth or {}
-        api_key = auth_payload.get("apiKey")
-        jwt = auth_payload.get("jwt")
+        auth_payload = message.auth
+        api_key = auth_payload.apiKey if auth_payload else None
+        jwt = auth_payload.jwt if auth_payload else None

         if settings.ws_api_key:
             if api_key != settings.ws_api_key:
@@ -217,10 +289,9 @@ class Session:
         self.authenticated = True
         self.protocol_version = message.version
         self.ws_state = WsSessionState.WAIT_START
-        await self.transport.send_event(
+        await self._send_event(
             ev(
                 "hello.ack",
                 sessionId=self.id,
                 version=self.protocol_version,
             )
         )
@@ -231,8 +302,12 @@ class Session:
             await self._send_error("client", "Duplicate session.start", "protocol.order")
             return

-        metadata = message.metadata or {}
-        metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
+        raw_metadata = message.metadata or {}
+        workflow_runtime = self._bootstrap_workflow(raw_metadata)
+        server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
+        client_runtime = self._sanitize_client_metadata(raw_metadata)
+        metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
+        metadata = self._merge_runtime_metadata(metadata, client_runtime)

         # Create history call record early so later turn callbacks can append transcripts.
         await self._start_history_bridge(metadata)
@@ -248,28 +323,37 @@ class Session:

         self.state = "accepted"
         self.ws_state = WsSessionState.ACTIVE
-        await self.transport.send_event(
+        await self._send_event(
             ev(
                 "session.started",
                 sessionId=self.id,
-                trackId=self.current_track_id,
-                audio=message.audio or {},
+                tracks={
+                    "audio_in": self.TRACK_AUDIO_IN,
+                    "audio_out": self.TRACK_AUDIO_OUT,
+                    "control": self.TRACK_CONTROL,
+                },
+                audio=message.audio.model_dump() if message.audio else {},
             )
         )
+        await self._send_event(
+            ev(
+                "config.resolved",
+                trackId=self.TRACK_CONTROL,
+                config=self._build_config_resolved(metadata),
+            )
+        )
         if self.workflow_runner and self._workflow_initial_node:
-            await self.transport.send_event(
+            await self._send_event(
                 ev(
                     "workflow.started",
                     sessionId=self.id,
                     workflowId=self.workflow_runner.workflow_id,
                     workflowName=self.workflow_runner.name,
                     nodeId=self._workflow_initial_node.id,
                 )
             )
-            await self.transport.send_event(
+            await self._send_event(
                 ev(
                     "workflow.node.entered",
                     sessionId=self.id,
                     workflowId=self.workflow_runner.workflow_id,
                     nodeId=self._workflow_initial_node.id,
                     nodeName=self._workflow_initial_node.name,
@@ -285,17 +369,24 @@ class Session:
         stop_reason = reason or "client_requested"
         self.state = "hungup"
         self.ws_state = WsSessionState.STOPPED
-        await self.transport.send_event(
+        await self._send_event(
             ev(
                 "session.stopped",
                 sessionId=self.id,
                 reason=stop_reason,
             )
         )
         await self._finalize_history(status="connected")
         await self.transport.close()

-    async def _send_error(self, sender: str, error_message: str, code: str) -> None:
+    async def _send_error(
+        self,
+        sender: str,
+        error_message: str,
+        code: str,
+        stage: Optional[str] = None,
+        retryable: Optional[bool] = None,
+        track_id: Optional[str] = None,
+    ) -> None:
         """
         Send error event to client.

@@ -304,13 +395,26 @@ class Session:
             error_message: Error message
             code: Machine-readable error code
         """
-        await self.transport.send_event(
+        resolved_stage = stage or self._infer_error_stage(code)
+        resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
+        resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
+        await self._send_event(
             ev(
                 "error",
                 sender=sender,
                 code=code,
                 message=error_message,
-                trackId=self.current_track_id,
+                stage=resolved_stage,
+                retryable=resolved_retryable,
+                trackId=resolved_track_id,
+                data={
+                    "error": {
+                        "stage": resolved_stage,
+                        "code": code,
+                        "message": error_message,
+                        "retryable": resolved_retryable,
+                    }
+                },
             )
         )
@@ -329,11 +433,12 @@ class Session:
         logger.info(f"Session {self.id} cleaning up")
         await self._finalize_history(status="connected")
         await self.pipeline.cleanup()
+        await self._history_bridge.shutdown()
         await self.transport.close()

     async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
         """Initialize backend history call record for this session."""
-        if self._history_call_id:
+        if self._history_bridge.call_id:
             return

         history_meta: Dict[str, Any] = {}
@@ -349,7 +454,7 @@ class Session:
         assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
         source = str(history_meta.get("source", metadata.get("source", "debug")))

-        call_id = await create_history_call_record(
+        call_id = await self._history_bridge.start_call(
             user_id=user_id,
             assistant_id=str(assistant_id) if assistant_id else None,
             source=source,
@@ -357,10 +462,6 @@ class Session:
         if not call_id:
             return

-        self._history_call_id = call_id
-        self._history_call_started_mono = time.monotonic()
-        self._history_turn_index = 0
-        self._history_finalized = False
         logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")

     async def _on_turn_complete(self, turn: ConversationTurn) -> None:
@@ -372,48 +473,11 @@ class Session:
|
||||
elif role == "assistant":
|
||||
await self._maybe_advance_workflow(turn.text.strip())
|
||||
|
||||
if not self._history_call_id:
|
||||
return
|
||||
if not turn.text or not turn.text.strip():
|
||||
return
|
||||
|
||||
role = (turn.role or "").lower()
|
||||
speaker = "human" if role == "user" else "ai"
|
||||
|
||||
end_ms = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
|
||||
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
|
||||
start_ms = max(0, end_ms - estimated_duration_ms)
|
||||
|
||||
turn_index = self._history_turn_index
|
||||
await add_history_transcript(
|
||||
call_id=self._history_call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=turn.text.strip(),
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
duration_ms=max(1, end_ms - start_ms),
|
||||
)
|
||||
self._history_turn_index += 1
|
||||
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
|
||||
|
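The timing heuristic deleted above estimated each turn's duration at roughly 80 ms per character, clamped to [300, 12000] ms; timestamping presumably now lives inside the bridge:

# len(text) == 50  -> max(300, min(12000, 50 * 80))  == 4000 ms
# len(text) == 2   -> max(300, min(12000, 2 * 80))   == 300 ms  (floor)
# len(text) == 400 -> max(300, min(12000, 400 * 80)) == 12000 ms (ceiling)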
    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once."""
        if not self._history_call_id or self._history_finalized:
            return

        duration_seconds = 0
        if self._history_call_started_mono is not None:
            duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))

        ok = await finalize_history_call_record(
            call_id=self._history_call_id,
            status=status,
            duration_seconds=duration_seconds,
        )
        if ok:
            self._history_finalized = True
        await self._history_bridge.finalize(status=status)

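The bridge class itself is outside this diff; from the call sites above (call_id, start_call, enqueue_turn, finalize, shutdown), its surface is roughly the following Protocol (a sketch, not the actual class):

from typing import Any, Optional, Protocol

class HistoryBridgeLike(Protocol):
    """Interface implied by the Session call sites; names are inferred."""
    call_id: Optional[str]

    async def start_call(
        self, user_id: Any, assistant_id: Optional[str], source: str
    ) -> Optional[str]: ...

    def enqueue_turn(self, role: str, text: str) -> None: ...  # fire-and-forget

    async def finalize(self, status: str) -> None: ...

    async def shutdown(self) -> None: ...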
    def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Parse workflow payload and return initial runtime overrides."""
@@ -483,10 +547,9 @@ class Session:
        node = transition.node
        edge = transition.edge

        await self.transport.send_event(
        await self._send_event(
            ev(
                "workflow.edge.taken",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
@@ -494,10 +557,9 @@ class Session:
                reason=reason,
            )
        )
        await self.transport.send_event(
        await self._send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
@@ -510,10 +572,9 @@ class Session:
        self.pipeline.apply_runtime_overrides(node_runtime)

        if node.node_type == "tool":
            await self.transport.send_event(
            await self._send_event(
                ev(
                    "workflow.tool.requested",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
@@ -522,10 +583,9 @@ class Session:
            return

        if node.node_type == "human_transfer":
            await self.transport.send_event(
            await self._send_event(
                ev(
                    "workflow.human_transfer",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
@@ -534,16 +594,77 @@ class Session:
            return

        if node.node_type == "end":
            await self.transport.send_event(
            await self._send_event(
                ev(
                    "workflow.ended",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")

    def _next_event_seq(self) -> int:
        self._event_seq += 1
        return self._event_seq

    def _event_source(self, event_type: str) -> str:
        if event_type.startswith("workflow."):
            return "system"
        if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
            return "system"
        if event_type == "error":
            return "system"
        return "system"

    def _infer_error_stage(self, code: str) -> str:
        normalized = str(code or "").strip().lower()
        if normalized.startswith("audio."):
            return "audio"
        if normalized.startswith("tool."):
            return "tool"
        if normalized.startswith("asr."):
            return "asr"
        if normalized.startswith("llm."):
            return "llm"
        if normalized.startswith("tts."):
            return "tts"
        return "protocol"

    def _error_track_id(self, stage: str, code: str) -> str:
        if stage in {"audio", "asr"}:
            return self.TRACK_AUDIO_IN
        if stage in {"llm", "tts", "tool"}:
            return self.TRACK_AUDIO_OUT
        if str(code or "").strip().lower().startswith("auth."):
            return self.TRACK_CONTROL
        return self.TRACK_CONTROL

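Together with _send_error, these helpers give a deterministic code -> stage -> track mapping (example codes below are hypothetical):

# "asr.connect_failed" -> stage "asr"      -> track audio_in,  retryable
# "tts.stream_error"   -> stage "tts"      -> track audio_out, retryable
# "auth.expired"       -> stage "protocol" -> track control,   not retryable
# "bad_payload"        -> stage "protocol" -> track control,   not retryable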
    def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
        event_type = str(event.get("type") or "")
        source = str(event.get("source") or self._event_source(event_type))
        track_id = event.get("trackId") or self.TRACK_CONTROL

        data = event.get("data")
        if not isinstance(data, dict):
            data = {}
        for k, v in event.items():
            if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
                continue
            data.setdefault(k, v)

        event["sessionId"] = self.id
        event["seq"] = self._next_event_seq()
        event["source"] = source
        event["trackId"] = track_id
        event["data"] = data
        return event

    async def _send_event(self, event: Dict[str, Any]) -> None:
        await self.transport.send_event(self._envelope_event(event))

    async def send_heartbeat(self) -> None:
        await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))

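Assuming ev() returns a flat dict of type plus its keyword fields, the envelope step behaves like this sketch:

# ev("workflow.ended", workflowId="w1", nodeId="n1") ->
#   {"type": "workflow.ended", "workflowId": "w1", "nodeId": "n1"}
# _envelope_event(...) ->
#   {"type": "workflow.ended", "workflowId": "w1", "nodeId": "n1",
#    "sessionId": <self.id>, "seq": <next>, "source": "system",
#    "trackId": "control",
#    "data": {"workflowId": "w1", "nodeId": "n1"}}
# Non-reserved top-level keys are mirrored into data via setdefault, while
# the flat keys are kept, presumably so older clients keep working.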
    async def _workflow_llm_route(
        self,
        node: WorkflowNodeDef,
@@ -629,6 +750,137 @@ class Session:
                merged[key] = value
        return merged

    async def _load_server_runtime_metadata(
        self,
        client_metadata: Dict[str, Any],
        workflow_runtime: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Load trusted runtime metadata from backend assistant config."""
        assistant_id = (
            workflow_runtime.get("assistantId")
            or client_metadata.get("assistantId")
            or client_metadata.get("appId")
            or client_metadata.get("app_id")
        )
        if assistant_id is None:
            return {}

        provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
        if not callable(provider):
            return {}

        payload = await provider(str(assistant_id).strip())
        if not isinstance(payload, dict):
            return {}

        assistant_cfg: Dict[str, Any] = {}
        session_start_cfg = payload.get("sessionStartMetadata")
        if isinstance(session_start_cfg, dict):
            assistant_cfg.update(session_start_cfg)
        if isinstance(payload.get("assistant"), dict):
            assistant_cfg.update(payload.get("assistant"))
        elif not assistant_cfg:
            assistant_cfg = payload

        if not isinstance(assistant_cfg, dict):
            return {}

        runtime: Dict[str, Any] = {}
        passthrough_keys = {
            "firstTurnMode",
            "generatedOpenerEnabled",
            "output",
            "bargeIn",
            "knowledgeBaseId",
            "knowledge",
            "history",
            "userId",
            "source",
            "tools",
            "services",
            "configVersionId",
            "config_version_id",
        }
        for key in passthrough_keys:
            if key in assistant_cfg:
                runtime[key] = assistant_cfg[key]

        if assistant_cfg.get("systemPrompt") is not None:
            runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
        elif assistant_cfg.get("prompt") is not None:
            runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")

        if assistant_cfg.get("greeting") is not None:
            runtime["greeting"] = assistant_cfg.get("greeting")
        elif assistant_cfg.get("opener") is not None:
            runtime["greeting"] = assistant_cfg.get("opener")

        resolved_assistant_id = (
            assistant_cfg.get("assistantId")
            or payload.get("assistantId")
            or assistant_id
        )
        runtime["assistantId"] = str(resolved_assistant_id)

        if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
            runtime["configVersionId"] = payload.get("configVersionId")
        if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
            runtime["configVersionId"] = payload.get("config_version_id")

        if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
            runtime["configVersionId"] = runtime.get("config_version_id")
        return runtime

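A minimal sketch of how a backend payload maps through this method (payload shape and values are illustrative, not the backend's actual schema):

payload = {
    "assistant": {"systemPrompt": "You are a support agent.", "greeting": "Hi!", "tools": []},
    "configVersionId": "cv_42",
}
# _load_server_runtime_metadata would yield roughly:
# {"tools": [], "systemPrompt": "You are a support agent.",
#  "greeting": "Hi!", "assistantId": "<the looked-up id>",
#  "configVersionId": "cv_42"}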
    def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sanitize untrusted metadata sources.

        This keeps only a small override whitelist and stable config ID fields.
        """
        if not isinstance(metadata, dict):
            return {}

        sanitized: Dict[str, Any] = {}
        for key in self._CLIENT_METADATA_ID_KEYS:
            if key in metadata:
                sanitized[key] = metadata[key]
        for key in self._CLIENT_METADATA_OVERRIDES:
            if key in metadata:
                sanitized[key] = metadata[key]

        return sanitized

    def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Apply client metadata whitelist and remove forbidden secrets."""
        sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
        if isinstance(metadata.get("services"), dict):
            logger.warning(
                "Session {} provided metadata.services from client; client-side service config is ignored",
                self.id,
            )
        return sanitized

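Assuming _CLIENT_METADATA_ID_KEYS contains "assistantId" and _CLIENT_METADATA_OVERRIDES contains "bargeIn" (their actual contents are defined elsewhere in the class), sanitation behaves like:

raw = {
    "assistantId": "a1",
    "bargeIn": {"enabled": True},
    "services": {"llm": {"apiKey": "sk-..."}},  # client-controlled, must not pass through
}
# _sanitize_client_metadata(raw) -> {"assistantId": "a1", "bargeIn": {"enabled": True}}
# plus a warning log that metadata.services was ignored.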
    def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Build public resolved config payload (secrets removed)."""
        system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
        prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
        runtime = self.pipeline.resolved_runtime_config()

        return {
            "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
            "channel": metadata.get("channel"),
            "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
            "prompt": {"sha256": prompt_hash},
            "output": runtime.get("output", {}),
            "services": runtime.get("services", {}),
            "tools": runtime.get("tools", {}),
            "tracks": {
                "audio_in": self.TRACK_AUDIO_IN,
                "audio_out": self.TRACK_AUDIO_OUT,
                "control": self.TRACK_CONTROL,
            },
        }

    def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
        """Best-effort extraction of a JSON object from freeform text."""
        try:

@@ -4,11 +4,13 @@ import asyncio
import ast
import operator
from datetime import datetime
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict, Optional

import aiohttp

from app.backend_client import fetch_tool_resource
from app.backend_adapters import build_backend_adapter_from_settings

ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]]

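Any coroutine matching this alias can be injected into execute_server_tool below; for example, a test stub (names here are hypothetical):

async def stub_fetcher(tool_id: str) -> Optional[Dict[str, Any]]:
    # In-memory stand-in for the backend tool-resource lookup.
    resources = {"weather_lookup": {"category": "query", "http_method": "GET"}}
    return resources.get(tool_id)

# result = await execute_server_tool(tool_call, tool_resource_fetcher=stub_fetcher)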
_BIN_OPS = {
    ast.Add: operator.add,
@@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    return {}


async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
    """Default tool resource resolver via backend adapter."""
    adapter = build_backend_adapter_from_settings()
    return await adapter.fetch_tool_resource(tool_id)


async def execute_server_tool(
    tool_call: Dict[str, Any],
    tool_resource_fetcher: Optional[ToolResourceFetcher] = None,
) -> Dict[str, Any]:
    """Execute a server-side tool and return normalized result payload."""
    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)
    resource_fetcher = tool_resource_fetcher or fetch_tool_resource

    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
@@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
        }

    if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
        resource = await fetch_tool_resource(tool_name)
        resource = await resource_fetcher(tool_name)
        if resource and str(resource.get("category") or "") == "query":
            method = str(resource.get("http_method") or "GET").strip().upper()
            if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: