From 72ed7d051285adb4456e923ddb4d38e904382200 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Thu, 26 Feb 2026 01:58:39 +0800 Subject: [PATCH] Unify db api --- .gitignore | 6 + api/.gitignore | 8 +- api/app/routers/assistants.py | 68 ++- api/app/schemas.py | 29 ++ api/tests/test_assistants.py | 23 + engine/.env.example | 71 +--- engine/.gitignore | 2 + engine/README.md | 53 ++- engine/app/backend_adapters.py | 357 ++++++++++++++++ engine/app/backend_client.py | 190 ++------- engine/app/config.py | 404 +++++++++++++++++- engine/app/main.py | 17 +- engine/config/agents/example.yaml | 50 +++ engine/config/agents/tools.yaml | 73 ++++ engine/core/duplex_pipeline.py | 482 ++++++++++++++++++--- engine/core/history_bridge.py | 244 +++++++++++ engine/core/ports/__init__.py | 17 + engine/core/ports/backend.py | 84 ++++ engine/core/session.py | 432 +++++++++++++++---- engine/core/tool_executor.py | 20 +- engine/docs/backend_integration.md | 47 ++ engine/docs/ws_v1_schema.md | 115 +++-- engine/docs/ws_v1_schema_zh.md | 520 +++++++++++++++++++++++ engine/examples/mic_client.py | 131 ++++-- engine/examples/simple_client.py | 92 +++- engine/examples/test_websocket.py | 31 +- engine/examples/wav_client.py | 189 +++++--- engine/examples/web_client.html | 25 +- engine/models/ws_v1.py | 64 ++- engine/requirements.txt | 1 + engine/services/llm.py | 18 +- engine/services/openai_compatible_asr.py | 12 +- engine/services/openai_compatible_tts.py | 10 +- engine/tests/test_agent_config.py | 252 +++++++++++ engine/tests/test_backend_adapters.py | 150 +++++++ engine/tests/test_history_bridge.py | 147 +++++++ engine/tests/test_tool_call_flow.py | 32 +- web/services/apiClient.ts | 1 + web/services/backendApi.ts | 47 +- web/tsconfig.json | 5 +- 40 files changed, 3926 insertions(+), 593 deletions(-) create mode 100644 .gitignore create mode 100644 engine/app/backend_adapters.py create mode 100644 engine/config/agents/example.yaml create mode 100644 engine/config/agents/tools.yaml create mode 100644 engine/core/history_bridge.py create mode 100644 engine/core/ports/__init__.py create mode 100644 engine/core/ports/backend.py create mode 100644 engine/docs/backend_integration.md create mode 100644 engine/docs/ws_v1_schema_zh.md create mode 100644 engine/tests/test_agent_config.py create mode 100644 engine/tests/test_backend_adapters.py create mode 100644 engine/tests/test_history_bridge.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9bcc58 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# OS artifacts +.DS_Store +Thumbs.db + +# Workspace runtime data +data/ diff --git a/api/.gitignore b/api/.gitignore index 8d12426..383a32a 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -36,8 +36,12 @@ env/ *.sqlite *.sqlite3 -# Vector store data -data/vector_store/ +# Runtime data (SQLite, vector store, uploads, generated artifacts) +data/** +!data/ +!data/.gitkeep +!data/vector_store/ +data/vector_store/** !data/vector_store/.gitkeep # IDE diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index 815e9f6..e6cdcea 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -1,13 +1,13 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from typing import Optional +from typing import Any, Dict, List, Optional import uuid from datetime import datetime from ..db import get_db from ..models import Assistant, LLMModel, ASRModel, Voice from ..schemas import ( - AssistantCreate, AssistantUpdate, AssistantOut + AssistantCreate, 
AssistantUpdate, AssistantOut, AssistantEngineConfigResponse ) router = APIRouter(prefix="/assistants", tags=["Assistants"]) @@ -52,8 +52,13 @@ def _normalize_openai_compatible_voice_key(voice_value: str, model: str) -> str: return f"{model_name}:{voice_id}" -def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: - metadata = { +def _config_version_id(assistant: Assistant) -> str: + updated = assistant.updated_at or assistant.created_at or datetime.utcnow() + return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}" + + +def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]: + metadata: Dict[str, Any] = { "systemPrompt": assistant.prompt or "", "firstTurnMode": assistant.first_turn_mode or "bot_first", "greeting": assistant.opener or "", @@ -64,10 +69,29 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: "minDurationMs": int(assistant.interruption_sensitivity or 500), }, "services": {}, + "tools": assistant.tools or [], + "history": { + "assistantId": assistant.id, + "userId": int(assistant.user_id or 1), + "source": "debug", + }, } - warnings = [] + warnings: List[str] = [] - if assistant.llm_model_id: + config_mode = str(assistant.config_mode or "platform").strip().lower() + + if config_mode in {"dify", "fastgpt"}: + metadata["services"]["llm"] = { + "provider": "openai", + "model": "", + "apiKey": assistant.api_key, + "baseUrl": assistant.api_url, + } + if not (assistant.api_url or "").strip(): + warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}") + if not (assistant.api_key or "").strip(): + warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}") + elif assistant.llm_model_id: llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first() if llm: metadata["services"]["llm"] = { @@ -87,6 +111,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: "provider": asr_provider, "model": asr.model_name or asr.name, "apiKey": asr.api_key if asr_provider == "openai_compatible" else None, + "baseUrl": asr.base_url if asr_provider == "openai_compatible" else None, } else: warnings.append(f"ASR model not found: {assistant.asr_model_id}") @@ -107,6 +132,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: "provider": tts_provider, "model": model, "apiKey": voice.api_key if tts_provider == "openai_compatible" else None, + "baseUrl": voice.base_url if tts_provider == "openai_compatible" else None, "voice": runtime_voice, "speed": assistant.speed or voice.speed, } @@ -126,10 +152,21 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: "kbId": assistant.knowledge_base_id, "nResults": 5, } + return metadata, warnings + + +def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[str, Any]: + session_metadata, warnings = _resolve_runtime_metadata(db, assistant) + config_version_id = _config_version_id(assistant) + assistant_cfg = dict(session_metadata) + assistant_cfg["assistantId"] = assistant.id + assistant_cfg["configVersionId"] = config_version_id return { "assistantId": assistant.id, - "sessionStartMetadata": metadata, + "configVersionId": config_version_id, + "assistant": assistant_cfg, + "sessionStartMetadata": session_metadata, "sources": { "llmModelId": assistant.llm_model_id, "asrModelId": assistant.asr_model_id, @@ -219,13 +256,22 @@ def get_assistant(id: str, db: Session = Depends(get_db)): return assistant_to_dict(assistant) 
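For reference, the resolved payload can be exercised directly against the new `/api/assistants/{id}/config` endpoint. The sketch below is illustrative only: the base URL and the `httpx` dependency are assumptions, not part of this patch.

```python
# Sketch: fetch the resolved engine config for an assistant.
# The base URL and httpx are assumptions for illustration; any HTTP client works.
import httpx

def fetch_engine_config(assistant_id: str, base_url: str = "http://127.0.0.1:8100") -> dict:
    resp = httpx.get(f"{base_url}/api/assistants/{assistant_id}/config", timeout=10.0)
    resp.raise_for_status()
    payload = resp.json()
    # Keys emitted by _build_engine_assistant_config above:
    #   assistantId, configVersionId ("asst_<id>_<YYYYMMDDHHMMSS>"),
    #   assistant, sessionStartMetadata, sources, warnings
    assert payload["assistantId"] == assistant_id
    return payload
```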
-@router.get("/{id}/runtime-config") -def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)): - """Resolve assistant runtime config for engine WS session.start metadata.""" +@router.get("/{id}/config", response_model=AssistantEngineConfigResponse) +def get_assistant_config(id: str, db: Session = Depends(get_db)): + """Canonical engine config endpoint consumed by engine backend adapter.""" assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") - return _resolve_runtime_metadata(db, assistant) + return _build_engine_assistant_config(db, assistant) + + +@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse) +def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)): + """Legacy alias for resolved engine runtime config.""" + assistant = db.query(Assistant).filter(Assistant.id == id).first() + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + return _build_engine_assistant_config(db, assistant) @router.post("", response_model=AssistantOut) diff --git a/api/app/schemas.py b/api/app/schemas.py index a2a486e..8a69287 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -333,6 +333,35 @@ class AssistantOut(AssistantBase): from_attributes = True +class AssistantRuntimeMetadata(BaseModel): + """Canonical runtime metadata payload consumed by engine session.start.""" + + model_config = ConfigDict(extra="allow") + + systemPrompt: str = "" + firstTurnMode: str = "bot_first" + greeting: str = "" + generatedOpenerEnabled: bool = False + output: Dict[str, Any] = Field(default_factory=dict) + bargeIn: Dict[str, Any] = Field(default_factory=dict) + services: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + tools: List[Any] = Field(default_factory=list) + knowledgeBaseId: Optional[str] = None + knowledge: Dict[str, Any] = Field(default_factory=dict) + history: Dict[str, Any] = Field(default_factory=dict) + assistantId: Optional[str] = None + configVersionId: Optional[str] = None + + +class AssistantEngineConfigResponse(BaseModel): + assistantId: str + configVersionId: Optional[str] = None + assistant: AssistantRuntimeMetadata + sessionStartMetadata: AssistantRuntimeMetadata + sources: Dict[str, Optional[str]] = Field(default_factory=dict) + warnings: List[str] = Field(default_factory=list) + + class AssistantStats(BaseModel): assistant_id: str total_calls: int = 0 diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index eaab617..10cf93c 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -183,12 +183,16 @@ class TestAssistantAPI: def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data): """Test resolved runtime config endpoint for WS session.start metadata.""" + sample_asr_model_data["vendor"] = "OpenAI Compatible" llm_resp = client.post("/api/llm", json=sample_llm_model_data) assert llm_resp.status_code == 200 asr_resp = client.post("/api/asr", json=sample_asr_model_data) assert asr_resp.status_code == 200 + sample_voice_data["vendor"] = "OpenAI Compatible" + sample_voice_data["base_url"] = "https://tts.example.com/v1/audio/speech" + sample_voice_data["api_key"] = "test-voice-key" voice_resp = client.post("/api/voices", json=sample_voice_data) assert voice_resp.status_code == 200 voice_id = voice_resp.json()["id"] @@ -215,7 +219,26 @@ class TestAssistantAPI: assert 
metadata["greeting"] == "runtime opener" assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"] assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"] + assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"] assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"] + assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"] + + def test_get_engine_config_endpoint(self, client, sample_assistant_data): + """Test canonical assistant config endpoint consumed by engine backend adapter.""" + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + config_resp = client.get(f"/api/assistants/{assistant_id}/config") + assert config_resp.status_code == 200 + payload = config_resp.json() + + assert payload["assistantId"] == assistant_id + assert payload["assistant"]["assistantId"] == assistant_id + assert payload["assistant"]["configVersionId"].startswith(f"asst_{assistant_id}_") + assert payload["assistant"]["systemPrompt"] == sample_assistant_data["prompt"] + assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"] + assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data): sample_assistant_data["voiceOutputEnabled"] = False diff --git a/engine/.env.example b/engine/.env.example index f62a4c6..0dd0cd0 100644 --- a/engine/.env.example +++ b/engine/.env.example @@ -11,9 +11,16 @@ PORT=8000 # EXTERNAL_IP=1.2.3.4 # Backend bridge (optional) +# BACKEND_MODE=auto|http|disabled +BACKEND_MODE=auto BACKEND_URL=http://127.0.0.1:8100 BACKEND_TIMEOUT_SEC=10 +HISTORY_ENABLED=true HISTORY_DEFAULT_USER_ID=1 +HISTORY_QUEUE_MAX_SIZE=256 +HISTORY_RETRY_MAX_ATTEMPTS=2 +HISTORY_RETRY_BACKOFF_SEC=0.2 +HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC=1.5 # Audio SAMPLE_RATE=16000 @@ -23,57 +30,21 @@ CHUNK_SIZE_MS=20 DEFAULT_CODEC=pcm MAX_AUDIO_BUFFER_SECONDS=30 -# VAD / EOU -VAD_TYPE=silero -VAD_MODEL_PATH=data/vad/silero_vad.onnx -# Higher = stricter speech detection (fewer false positives, more misses). -VAD_THRESHOLD=0.5 -# Require this much continuous speech before utterance can be valid. -VAD_MIN_SPEECH_DURATION_MS=100 -# Silence duration required to finalize one user turn. -VAD_EOU_THRESHOLD_MS=800 +# Agent profile selection (optional fallback when CLI args are not used) +# Prefer CLI: +# python -m app.main --agent-config config/agents/default.yaml +# python -m app.main --agent-profile default +# AGENT_CONFIG_PATH=config/agents/default.yaml +# AGENT_PROFILE=default +AGENT_CONFIG_DIR=config/agents -# LLM -OPENAI_API_KEY=your_openai_api_key_here -# Optional for OpenAI-compatible providers. -# OPENAI_API_URL=https://api.openai.com/v1 -LLM_MODEL=gpt-4o-mini -LLM_TEMPERATURE=0.7 - -# TTS -# edge: no API key needed -# openai_compatible: compatible with SiliconFlow-style endpoints -TTS_PROVIDER=openai_compatible -TTS_VOICE=anna -TTS_SPEED=1.0 - -# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible) -SILICONFLOW_API_KEY=your_siliconflow_api_key_here -SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B -SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall - -# ASR -ASR_PROVIDER=openai_compatible -# Interim cadence and minimum audio before interim decode. 
-ASR_INTERIM_INTERVAL_MS=500
-ASR_MIN_AUDIO_MS=300
-# ASR start gate: ignore micro-noise, then commit to one turn once started.
-ASR_START_MIN_SPEECH_MS=160
-# Pre-roll protects beginning phonemes.
-ASR_PRE_SPEECH_MS=240
-# Tail silence protects ending phonemes.
-ASR_FINAL_TAIL_MS=120
-
-# Duplex behavior
-DUPLEX_ENABLED=true
-# DUPLEX_GREETING=Hello! How can I help you today?
-DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
-
-# Barge-in (user interrupting assistant)
-# Min user speech duration needed to interrupt assistant audio.
-BARGE_IN_MIN_DURATION_MS=200
-# Allowed silence during potential barge-in (ms) before reset.
-BARGE_IN_SILENCE_TOLERANCE_MS=60
+# Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY}
+# LLM_API_KEY=your_llm_api_key_here
+# LLM_API_URL=https://api.openai.com/v1
+# TTS_API_KEY=your_tts_api_key_here
+# TTS_API_URL=https://api.example.com/v1/audio/speech
+# ASR_API_KEY=your_asr_api_key_here
+# ASR_API_URL=https://api.example.com/v1/audio/transcriptions
 
 # Logging
 LOG_LEVEL=INFO
diff --git a/engine/.gitignore b/engine/.gitignore
index 5cd10e8..201a051 100644
--- a/engine/.gitignore
+++ b/engine/.gitignore
@@ -146,3 +146,5 @@ cython_debug/
 recordings/
 logs/
 running/
+
+config/agents/default.yaml
diff --git a/engine/README.md b/engine/README.md
index 17d9e3a..3353270 100644
--- a/engine/README.md
+++ b/engine/README.md
@@ -14,6 +14,57 @@ It is currently in an early, experimental stage.
 uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
 ```
 
+Run with an agent profile (recommended)
+
+```
+python -m app.main --agent-profile default
+```
+
+Run with an explicit YAML file
+
+```
+python -m app.main --agent-config config/agents/default.yaml
+```
+
+Agent config path precedence
+1. `--agent-config`
+2. `--agent-profile` (maps to `config/agents/<profile>.yaml`)
+3. `AGENT_CONFIG_PATH`
+4. `AGENT_PROFILE`
+5. `config/agents/default.yaml` (if it exists)
+
+Notes
+- Agent configuration is strict: if the YAML is missing a required key, startup fails immediately; there is no fallback to `.env` or code defaults.
+- To reference an environment variable, write `${ENV_VAR}` explicitly in the YAML.
+- The standalone `siliconflow` section has been removed; configure `provider`, `api_key`, `api_url`, and `model` inside `agent.llm` / `agent.tts` / `agent.asr`.
+- `agent.tools` (a list) can now be declared in the agent YAML to expose tools callable at runtime.
+- See `config/agents/tools.yaml` for a tool configuration example.
+
+## Backend Integration
+
+Engine runtime now supports adapter-based backend integration:
+
+- `BACKEND_MODE=auto|http|disabled`
+- `BACKEND_URL` + `BACKEND_TIMEOUT_SEC`
+- `HISTORY_ENABLED=true|false`
+
+Behavior:
+
+- `auto`: use the HTTP backend only when `BACKEND_URL` is set, otherwise engine-only mode.
+- `http`: force the HTTP backend; falls back to engine-only mode when the URL is missing.
+- `disabled`: force engine-only mode (no backend calls).
+
+History write path is now asynchronous and buffered per session:
+
+- `HISTORY_QUEUE_MAX_SIZE`
+- `HISTORY_RETRY_MAX_ATTEMPTS`
+- `HISTORY_RETRY_BACKOFF_SEC`
+- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC`
+
+This keeps turn processing responsive even when backend history APIs are slow or failing.
+
+Detailed notes: `docs/backend_integration.md`.
+
 Tests
 
 ```
@@ -28,4 +79,4 @@ python mic_client.py
 ```
 
 `/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
-See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`.
+See `docs/ws_v1_schema.md`.
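To make the precedence list concrete, here is a minimal sketch of resolving settings programmatically. `load_settings` and its keyword arguments come from `engine/app/config.py` in this patch; the profile name `example` points at the sample YAML shipped in `config/agents/`, and the snippet assumes it runs from the `engine/` directory with `AGENT_CONFIG_PATH` and `AGENT_PROFILE` unset.

```python
# Sketch: resolve engine settings from an explicit agent profile.
# Assumes CWD is engine/ and AGENT_CONFIG_PATH/AGENT_PROFILE are unset;
# argv=[] keeps real CLI flags out of the resolution.
from app.config import load_settings

settings = load_settings(agent_profile="example", argv=[])
print(settings.agent_config_source)  # "cli_profile": explicit arguments are treated like CLI flags
print(settings.agent_config_path)    # absolute path to config/agents/example.yaml
```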
diff --git a/engine/app/backend_adapters.py b/engine/app/backend_adapters.py new file mode 100644 index 0000000..a05bd8f --- /dev/null +++ b/engine/app/backend_adapters.py @@ -0,0 +1,357 @@ +"""Backend adapter implementations for engine integration ports.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import aiohttp +from loguru import logger + +from app.config import settings + + +class NullBackendAdapter: + """No-op adapter for engine-only runtime without backend dependencies.""" + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + _ = assistant_id + return None + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + _ = (user_id, assistant_id, source) + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + _ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms) + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + _ = (call_id, status, duration_seconds) + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + _ = (kb_id, query, n_results) + return [] + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + _ = tool_id + return None + + +class HistoryDisabledBackendAdapter: + """Adapter wrapper that disables history writes while keeping reads available.""" + + def __init__(self, delegate: HttpBackendAdapter | NullBackendAdapter): + self._delegate = delegate + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + return await self._delegate.fetch_assistant_config(assistant_id) + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + _ = (user_id, assistant_id, source) + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + _ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms) + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + _ = (call_id, status, duration_seconds) + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + return await self._delegate.search_knowledge_context( + kb_id=kb_id, + query=query, + n_results=n_results, + ) + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + return await self._delegate.fetch_tool_resource(tool_id) + + +class HttpBackendAdapter: + """HTTP implementation of backend integration ports.""" + + def __init__(self, backend_url: str, timeout_sec: int = 10): + base_url = str(backend_url or "").strip().rstrip("/") + if not base_url: + raise ValueError("backend_url is required for HttpBackendAdapter") + self._base_url = base_url + self._timeout_sec = timeout_sec + + def _timeout(self) -> aiohttp.ClientTimeout: + return 
aiohttp.ClientTimeout(total=self._timeout_sec) + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + """Fetch assistant config payload from backend API. + + Expected response shape: + { + "assistant": {...}, + "voice": {...} | null + } + """ + url = f"{self._base_url}/api/assistants/{assistant_id}/config" + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.get(url) as resp: + if resp.status == 404: + logger.warning(f"Assistant config not found: {assistant_id}") + return None + resp.raise_for_status() + payload = await resp.json() + if not isinstance(payload, dict): + logger.warning("Assistant config payload is not a dict; ignoring") + return None + return payload + except Exception as exc: + logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") + return None + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + """Create a call record via backend history API and return call_id.""" + url = f"{self._base_url}/api/history" + payload: Dict[str, Any] = { + "user_id": user_id, + "assistant_id": assistant_id, + "source": source, + "status": "connected", + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + data = await resp.json() + call_id = str((data or {}).get("id") or "") + return call_id or None + except Exception as exc: + logger.warning(f"Failed to create history call record: {exc}") + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + """Append a transcript segment to backend history.""" + if not call_id: + return False + + url = f"{self._base_url}/api/history/{call_id}/transcripts" + payload: Dict[str, Any] = { + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "confidence": confidence, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}") + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + """Finalize a call record with status and duration.""" + if not call_id: + return False + + url = f"{self._base_url}/api/history/{call_id}" + payload: Dict[str, Any] = { + "status": status, + "duration_seconds": duration_seconds, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.put(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + """Search backend knowledge base and return retrieval results.""" + if not kb_id or not query.strip(): + return [] + try: + safe_n_results = max(1, int(n_results)) + except (TypeError, ValueError): + safe_n_results = 5 + + url = 
f"{self._base_url}/api/knowledge/search" + payload: Dict[str, Any] = { + "kb_id": kb_id, + "query": query, + "nResults": safe_n_results, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + if resp.status == 404: + logger.warning(f"Knowledge base not found for retrieval: {kb_id}") + return [] + resp.raise_for_status() + data = await resp.json() + if not isinstance(data, dict): + return [] + results = data.get("results", []) + if not isinstance(results, list): + return [] + return [r for r in results if isinstance(r, dict)] + except Exception as exc: + logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") + return [] + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch tool resource configuration from backend API.""" + if not tool_id: + return None + + url = f"{self._base_url}/api/tools/resources/{tool_id}" + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.get(url) as resp: + if resp.status == 404: + return None + resp.raise_for_status() + data = await resp.json() + return data if isinstance(data, dict) else None + except Exception as exc: + logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}") + return None + + +def build_backend_adapter( + *, + backend_url: Optional[str], + backend_mode: str = "auto", + history_enabled: bool = True, + timeout_sec: int = 10, +) -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter: + """Create backend adapter implementation based on runtime settings.""" + mode = str(backend_mode or "auto").strip().lower() + has_url = bool(str(backend_url or "").strip()) + + base_adapter: HttpBackendAdapter | NullBackendAdapter + if mode in {"disabled", "off", "none", "null", "engine_only", "engine-only"}: + base_adapter = NullBackendAdapter() + elif mode == "http": + if has_url: + base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec) + else: + logger.warning("BACKEND_MODE=http but BACKEND_URL is empty; falling back to NullBackendAdapter") + base_adapter = NullBackendAdapter() + else: + if has_url: + base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec) + else: + base_adapter = NullBackendAdapter() + + if not history_enabled: + return HistoryDisabledBackendAdapter(base_adapter) + return base_adapter + + +def build_backend_adapter_from_settings() -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter: + """Create backend adapter using current app settings.""" + return build_backend_adapter( + backend_url=settings.backend_url, + backend_mode=settings.backend_mode, + history_enabled=settings.history_enabled, + timeout_sec=settings.backend_timeout_sec, + ) diff --git a/engine/app/backend_client.py b/engine/app/backend_client.py index b750564..93ea183 100644 --- a/engine/app/backend_client.py +++ b/engine/app/backend_client.py @@ -1,56 +1,19 @@ -"""Backend API client for assistant config and history persistence.""" +"""Compatibility wrappers around backend adapter implementations.""" from __future__ import annotations from typing import Any, Dict, List, Optional -import aiohttp -from loguru import logger +from app.backend_adapters import build_backend_adapter_from_settings -from app.config import settings + +def _adapter(): + return build_backend_adapter_from_settings() async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]: - """Fetch 
assistant config payload from backend API. - - Expected response shape: - { - "assistant": {...}, - "voice": {...} | null - } - """ - if not settings.backend_url: - logger.warning("BACKEND_URL not set; skipping assistant config fetch") - return None - - url = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config" - timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec) - - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url) as resp: - if resp.status == 404: - logger.warning(f"Assistant config not found: {assistant_id}") - return None - resp.raise_for_status() - payload = await resp.json() - if not isinstance(payload, dict): - logger.warning("Assistant config payload is not a dict; ignoring") - return None - return payload - except Exception as exc: - logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") - return None - - -def _backend_base_url() -> Optional[str]: - if not settings.backend_url: - return None - return settings.backend_url.rstrip("/") - - -def _timeout() -> aiohttp.ClientTimeout: - return aiohttp.ClientTimeout(total=settings.backend_timeout_sec) + """Fetch assistant config payload from backend adapter.""" + return await _adapter().fetch_assistant_config(assistant_id) async def create_history_call_record( @@ -60,28 +23,11 @@ async def create_history_call_record( source: str = "debug", ) -> Optional[str]: """Create a call record via backend history API and return call_id.""" - base_url = _backend_base_url() - if not base_url: - return None - - url = f"{base_url}/api/history" - payload: Dict[str, Any] = { - "user_id": user_id, - "assistant_id": assistant_id, - "source": source, - "status": "connected", - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - resp.raise_for_status() - data = await resp.json() - call_id = str((data or {}).get("id") or "") - return call_id or None - except Exception as exc: - logger.warning(f"Failed to create history call record: {exc}") - return None + return await _adapter().create_call_record( + user_id=user_id, + assistant_id=assistant_id, + source=source, + ) async def add_history_transcript( @@ -96,29 +42,16 @@ async def add_history_transcript( duration_ms: Optional[int] = None, ) -> bool: """Append a transcript segment to backend history.""" - base_url = _backend_base_url() - if not base_url or not call_id: - return False - - url = f"{base_url}/api/history/{call_id}/transcripts" - payload: Dict[str, Any] = { - "turn_index": turn_index, - "speaker": speaker, - "content": content, - "confidence": confidence, - "start_ms": start_ms, - "end_ms": end_ms, - "duration_ms": duration_ms, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - resp.raise_for_status() - return True - except Exception as exc: - logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}") - return False + return await _adapter().add_transcript( + call_id=call_id, + turn_index=turn_index, + speaker=speaker, + content=content, + start_ms=start_ms, + end_ms=end_ms, + confidence=confidence, + duration_ms=duration_ms, + ) async def finalize_history_call_record( @@ -128,24 +61,11 @@ async def finalize_history_call_record( duration_seconds: int, ) -> bool: """Finalize a call record with status and duration.""" - base_url = _backend_base_url() - if not base_url or not call_id: - 
return False - - url = f"{base_url}/api/history/{call_id}" - payload: Dict[str, Any] = { - "status": status, - "duration_seconds": duration_seconds, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.put(url, json=payload) as resp: - resp.raise_for_status() - return True - except Exception as exc: - logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") - return False + return await _adapter().finalize_call_record( + call_id=call_id, + status=status, + duration_seconds=duration_seconds, + ) async def search_knowledge_context( @@ -155,57 +75,13 @@ async def search_knowledge_context( n_results: int = 5, ) -> List[Dict[str, Any]]: """Search backend knowledge base and return retrieval results.""" - base_url = _backend_base_url() - if not base_url: - return [] - if not kb_id or not query.strip(): - return [] - try: - safe_n_results = max(1, int(n_results)) - except (TypeError, ValueError): - safe_n_results = 5 - - url = f"{base_url}/api/knowledge/search" - payload: Dict[str, Any] = { - "kb_id": kb_id, - "query": query, - "nResults": safe_n_results, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - if resp.status == 404: - logger.warning(f"Knowledge base not found for retrieval: {kb_id}") - return [] - resp.raise_for_status() - data = await resp.json() - if not isinstance(data, dict): - return [] - results = data.get("results", []) - if not isinstance(results, list): - return [] - return [r for r in results if isinstance(r, dict)] - except Exception as exc: - logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") - return [] + return await _adapter().search_knowledge_context( + kb_id=kb_id, + query=query, + n_results=n_results, + ) async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]: """Fetch tool resource configuration from backend API.""" - base_url = _backend_base_url() - if not base_url or not tool_id: - return None - - url = f"{base_url}/api/tools/resources/{tool_id}" - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.get(url) as resp: - if resp.status == 404: - return None - resp.raise_for_status() - data = await resp.json() - return data if isinstance(data, dict) else None - except Exception as exc: - logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}") - return None + return await _adapter().fetch_tool_resource(tool_id) diff --git a/engine/app/config.py b/engine/app/config.py index 1e3e1b3..2d1a680 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -1,9 +1,360 @@ -"""Configuration management using Pydantic settings.""" +"""Configuration management using Pydantic settings and agent YAML profiles.""" + +import json +import os +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple -from typing import List, Optional from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict -import json + +try: + import yaml +except ImportError: # pragma: no cover - validated when agent YAML is used + yaml = None + + +_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}") +_DEFAULT_AGENT_CONFIG_DIR = "config/agents" +_DEFAULT_AGENT_CONFIG_FILE = "default.yaml" +_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = { + "vad": { + "type": "vad_type", + "model_path": "vad_model_path", + "threshold": "vad_threshold", + 
"min_speech_duration_ms": "vad_min_speech_duration_ms", + "eou_threshold_ms": "vad_eou_threshold_ms", + }, + "llm": { + "provider": "llm_provider", + "model": "llm_model", + "temperature": "llm_temperature", + "api_key": "llm_api_key", + "api_url": "llm_api_url", + }, + "tts": { + "provider": "tts_provider", + "api_key": "tts_api_key", + "api_url": "tts_api_url", + "model": "tts_model", + "voice": "tts_voice", + "speed": "tts_speed", + }, + "asr": { + "provider": "asr_provider", + "api_key": "asr_api_key", + "api_url": "asr_api_url", + "model": "asr_model", + "interim_interval_ms": "asr_interim_interval_ms", + "min_audio_ms": "asr_min_audio_ms", + "start_min_speech_ms": "asr_start_min_speech_ms", + "pre_speech_ms": "asr_pre_speech_ms", + "final_tail_ms": "asr_final_tail_ms", + }, + "duplex": { + "enabled": "duplex_enabled", + "greeting": "duplex_greeting", + "system_prompt": "duplex_system_prompt", + }, + "barge_in": { + "min_duration_ms": "barge_in_min_duration_ms", + "silence_tolerance_ms": "barge_in_silence_tolerance_ms", + }, +} +_AGENT_SETTING_KEYS = { + "vad_type", + "vad_model_path", + "vad_threshold", + "vad_min_speech_duration_ms", + "vad_eou_threshold_ms", + "llm_provider", + "llm_api_key", + "llm_api_url", + "llm_model", + "llm_temperature", + "tts_provider", + "tts_api_key", + "tts_api_url", + "tts_model", + "tts_voice", + "tts_speed", + "asr_provider", + "asr_api_key", + "asr_api_url", + "asr_model", + "asr_interim_interval_ms", + "asr_min_audio_ms", + "asr_start_min_speech_ms", + "asr_pre_speech_ms", + "asr_final_tail_ms", + "duplex_enabled", + "duplex_greeting", + "duplex_system_prompt", + "barge_in_min_duration_ms", + "barge_in_silence_tolerance_ms", + "tools", +} +_BASE_REQUIRED_AGENT_SETTING_KEYS = { + "vad_type", + "vad_model_path", + "vad_threshold", + "vad_min_speech_duration_ms", + "vad_eou_threshold_ms", + "llm_provider", + "llm_model", + "llm_temperature", + "tts_provider", + "tts_voice", + "tts_speed", + "asr_provider", + "asr_interim_interval_ms", + "asr_min_audio_ms", + "asr_start_min_speech_ms", + "asr_pre_speech_ms", + "asr_final_tail_ms", + "duplex_enabled", + "duplex_system_prompt", + "barge_in_min_duration_ms", + "barge_in_silence_tolerance_ms", +} +_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} + + +def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str: + return str(overrides.get(key) or default).strip().lower() + + +def _is_blank(value: Any) -> bool: + return value is None or (isinstance(value, str) and not value.strip()) + + +@dataclass(frozen=True) +class AgentConfigSelection: + """Resolved agent config location and how it was selected.""" + + path: Optional[Path] + source: str + + +def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]: + """Parse only agent-related CLI flags from argv.""" + config_path: Optional[str] = None + profile: Optional[str] = None + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith("--agent-config="): + config_path = arg.split("=", 1)[1].strip() or None + elif arg == "--agent-config" and i + 1 < len(argv): + config_path = argv[i + 1].strip() or None + i += 1 + elif arg.startswith("--agent-profile="): + profile = arg.split("=", 1)[1].strip() or None + elif arg == "--agent-profile" and i + 1 < len(argv): + profile = argv[i + 1].strip() or None + i += 1 + i += 1 + return config_path, profile + + +def _agent_config_dir() -> Path: + base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR)) + if not 
base_dir.is_absolute(): + base_dir = Path.cwd() / base_dir + return base_dir.resolve() + + +def _resolve_agent_selection( + agent_config_path: Optional[str] = None, + agent_profile: Optional[str] = None, + argv: Optional[List[str]] = None, +) -> AgentConfigSelection: + cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:])) + path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH") + profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE") + source = "none" + candidate: Optional[Path] = None + + if path_value: + source = "cli_path" if (agent_config_path or cli_path) else "env_path" + candidate = Path(path_value) + elif profile_value: + source = "cli_profile" if (agent_profile or cli_profile) else "env_profile" + candidate = _agent_config_dir() / f"{profile_value}.yaml" + else: + fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE + if fallback.exists(): + source = "default" + candidate = fallback + + if candidate is None: + raise ValueError( + "Agent YAML config is required. Provide --agent-config/--agent-profile " + "or create config/agents/default.yaml." + ) + + if not candidate.is_absolute(): + candidate = (Path.cwd() / candidate).resolve() + else: + candidate = candidate.resolve() + + if not candidate.exists(): + raise ValueError(f"Agent config file not found ({source}): {candidate}") + if not candidate.is_file(): + raise ValueError(f"Agent config path is not a file: {candidate}") + return AgentConfigSelection(path=candidate, source=source) + + +def _resolve_env_refs(value: Any) -> Any: + """Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively.""" + if isinstance(value, dict): + return {k: _resolve_env_refs(v) for k, v in value.items()} + if isinstance(value, list): + return [_resolve_env_refs(item) for item in value] + if not isinstance(value, str) or "${" not in value: + return value + + def _replace(match: re.Match[str]) -> str: + env_key = match.group(1) + default_value = match.group(2) + env_value = os.getenv(env_key) + if env_value is None: + if default_value is None: + raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}") + return default_value + return env_value + + return _ENV_REF_PATTERN.sub(_replace, value) + + +def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]: + """Normalize YAML into flat Settings fields.""" + normalized: Dict[str, Any] = {} + + for key, value in raw.items(): + if key == "siliconflow": + raise ValueError( + "Section 'siliconflow' is no longer supported. " + "Move provider-specific fields into agent.llm / agent.asr / agent.tts." 
+ ) + if key == "tools": + if not isinstance(value, list): + raise ValueError("Agent config key 'tools' must be a list") + normalized["tools"] = value + continue + section_map = _AGENT_SECTION_KEY_MAP.get(key) + if section_map is None: + normalized[key] = value + continue + + if not isinstance(value, dict): + raise ValueError(f"Agent config section '{key}' must be a mapping") + + for nested_key, nested_value in value.items(): + mapped_key = section_map.get(nested_key) + if mapped_key is None: + raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'") + normalized[mapped_key] = nested_value + + unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS) + if unknown_keys: + raise ValueError( + "Unknown agent config keys in YAML: " + + ", ".join(unknown_keys) + ) + return normalized + + +def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]: + missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides)) + string_required = { + "vad_type", + "vad_model_path", + "llm_provider", + "llm_model", + "tts_provider", + "tts_voice", + "asr_provider", + "duplex_system_prompt", + } + for key in string_required: + if key in overrides and _is_blank(overrides.get(key)): + missing.add(key) + + llm_provider = _normalized_provider(overrides, "llm_provider", "openai") + if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai": + if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")): + missing.add("llm_api_key") + + tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible") + if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")): + missing.add("tts_api_key") + if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")): + missing.add("tts_api_url") + if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")): + missing.add("tts_model") + + asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible") + if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")): + missing.add("asr_api_key") + if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")): + missing.add("asr_api_url") + if "asr_model" not in overrides or _is_blank(overrides.get("asr_model")): + missing.add("asr_model") + + return sorted(missing) + + +def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]: + if yaml is None: + raise RuntimeError( + "PyYAML is required for agent YAML configuration. 
Install with: pip install pyyaml" + ) + + with selection.path.open("r", encoding="utf-8") as file: + raw = yaml.safe_load(file) or {} + + if not isinstance(raw, dict): + raise ValueError(f"Agent config must be a YAML mapping: {selection.path}") + + if "agent" in raw: + agent_value = raw["agent"] + if not isinstance(agent_value, dict): + raise ValueError("The 'agent' key in YAML must be a mapping") + raw = agent_value + + resolved = _resolve_env_refs(raw) + overrides = _normalize_agent_overrides(resolved) + missing_required = _missing_required_keys(overrides) + if missing_required: + raise ValueError( + f"Missing required agent settings in YAML ({selection.path}): " + + ", ".join(missing_required) + ) + + overrides["agent_config_path"] = str(selection.path) + overrides["agent_config_source"] = selection.source + return overrides + + +def load_settings( + agent_config_path: Optional[str] = None, + agent_profile: Optional[str] = None, + argv: Optional[List[str]] = None, +) -> "Settings": + """Load settings from .env and optional agent YAML.""" + selection = _resolve_agent_selection( + agent_config_path=agent_config_path, + agent_profile=agent_profile, + argv=argv, + ) + agent_overrides = _load_agent_overrides(selection) + return Settings(**agent_overrides) class Settings(BaseSettings): @@ -37,30 +388,35 @@ class Settings(BaseSettings): vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds") vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds") - # OpenAI / LLM Configuration - openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key") - openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)") + # LLM Configuration + llm_provider: str = Field( + default="openai", + description="LLM provider (openai, openai_compatible, siliconflow)" + ) + llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key") + llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL") llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation") # TTS Configuration tts_provider: str = Field( default="openai_compatible", - description="TTS provider (edge, openai_compatible; siliconflow alias supported)" + description="TTS provider (edge, openai_compatible, siliconflow)" ) + tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key") + tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") + tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_voice: str = Field(default="anna", description="TTS voice name") tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") - # SiliconFlow Configuration - siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key") - siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model") - # ASR Configuration asr_provider: str = Field( default="openai_compatible", - description="ASR provider (openai_compatible, buffered; siliconflow alias supported)" + description="ASR provider (openai_compatible, buffered, siliconflow)" ) - siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model") + asr_api_key: 
Optional[str] = Field(default=None, description="ASR provider API key") + asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") + asr_model: Optional[str] = Field(default=None, description="ASR model name") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") asr_start_min_speech_ms: int = Field( @@ -94,6 +450,10 @@ class Settings(BaseSettings): description="How much silence (ms) is tolerated during potential barge-in before reset" ) + # Optional tool declarations from agent YAML. + # Supports OpenAI function schema style entries and/or shorthand string names. + tools: List[Any] = Field(default_factory=list, description="Default tool definitions for runtime") + # Logging log_level: str = Field(default="INFO", description="Logging level") log_format: str = Field(default="json", description="Log format (json or text)") @@ -118,9 +478,25 @@ class Settings(BaseSettings): ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set") # Backend bridge configuration (for call/transcript persistence) + backend_mode: str = Field( + default="auto", + description="Backend integration mode: auto | http | disabled" + ) backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)") backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds") + history_enabled: bool = Field(default=True, description="Enable history write bridge") history_default_user_id: int = Field(default=1, description="Fallback user_id for history records") + history_queue_max_size: int = Field(default=256, description="Max buffered transcript writes per session") + history_retry_max_attempts: int = Field(default=2, description="Retry attempts for each transcript write") + history_retry_backoff_sec: float = Field(default=0.2, description="Base retry backoff for transcript writes") + history_finalize_drain_timeout_sec: float = Field( + default=1.5, + description="Max wait before finalizing history when queue is still draining" + ) + + # Agent YAML metadata + agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path") + agent_config_source: str = Field(default="none", description="How the agent YAML was selected") @property def chunk_size_bytes(self) -> int: @@ -146,7 +522,7 @@ class Settings(BaseSettings): # Global settings instance -settings = Settings() +settings = load_settings() def get_settings() -> Settings: diff --git a/engine/app/main.py b/engine/app/main.py index 259204c..c13daba 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -20,11 +20,11 @@ except ImportError: logger.warning("aiortc not available - WebRTC endpoint will be disabled") from app.config import settings +from app.backend_adapters import build_backend_adapter_from_settings from core.transports import SocketTransport, WebRtcTransport, BaseTransport from core.session import Session from processors.tracks import Resampled16kTrack from core.events import get_event_bus, reset_event_bus -from models.ws_v1 import ev # Check interval for heartbeat/timeout (seconds) _HEARTBEAT_CHECK_INTERVAL_SEC = 5 @@ -54,9 +54,7 @@ async def heartbeat_and_timeout_task( break if now - last_heartbeat_at[0] >= heartbeat_interval_sec: try: - await transport.send_event({ - **ev("heartbeat"), - }) + await 
session.send_heartbeat()
                    last_heartbeat_at[0] = now
                except Exception as e:
                    logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
@@ -78,6 +76,7 @@ app.add_middleware(
 
 # Active sessions storage
 active_sessions: Dict[str, Session] = {}
+backend_gateway = build_backend_adapter_from_settings()
 
 # Configure logging
 logger.remove()
@@ -167,7 +166,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
     # Create transport and session
     transport = SocketTransport(websocket)
-    session = Session(session_id, transport)
+    session = Session(session_id, transport, backend_gateway=backend_gateway)
     active_sessions[session_id] = session
 
     logger.info(f"WebSocket connection established: {session_id}")
@@ -246,7 +245,7 @@ async def webrtc_endpoint(websocket: WebSocket):
 
     # Create transport and session
     transport = WebRtcTransport(websocket, pc)
-    session = Session(session_id, transport)
+    session = Session(session_id, transport, backend_gateway=backend_gateway)
    active_sessions[session_id] = session
 
     logger.info(f"WebRTC connection established: {session_id}")
@@ -360,6 +359,12 @@ async def startup_event():
     logger.info(f"Server: {settings.host}:{settings.port}")
     logger.info(f"Sample rate: {settings.sample_rate} Hz")
     logger.info(f"VAD model: {settings.vad_model_path}")
+    if settings.agent_config_path:
+        logger.info(
+            f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}"
+        )
+    else:
+        logger.info("Agent config: none (using .env/default agent values)")
 
 
 @app.on_event("shutdown")
diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml
new file mode 100644
index 0000000..114830e
--- /dev/null
+++ b/engine/config/agents/example.yaml
@@ -0,0 +1,50 @@
+# Agent behavior configuration (safe to edit per profile)
+# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
+# Infra/server/network settings should stay in .env.
+
+agent:
+  vad:
+    type: silero
+    model_path: data/vad/silero_vad.onnx
+    threshold: 0.5
+    min_speech_duration_ms: 100
+    eou_threshold_ms: 800
+
+  llm:
+    # provider: openai | openai_compatible | siliconflow
+    provider: openai_compatible
+    model: deepseek-v3
+    temperature: 0.7
+    # Required: no fallback. You can still reference env explicitly.
+    api_key: your_llm_api_key
+    # Optional for OpenAI-compatible endpoints:
+    api_url: https://api.qnaigc.com/v1
+
+  tts:
+    # provider: edge | openai_compatible | siliconflow
+    provider: openai_compatible
+    api_key: your_tts_api_key
+    api_url: https://api.siliconflow.cn/v1/audio/speech
+    model: FunAudioLLM/CosyVoice2-0.5B
+    voice: anna
+    speed: 1.0
+
+  asr:
+    # provider: buffered | openai_compatible | siliconflow
+    provider: openai_compatible
+    api_key: your_asr_api_key
+    api_url: https://api.siliconflow.cn/v1/audio/transcriptions
+    model: FunAudioLLM/SenseVoiceSmall
+    interim_interval_ms: 500
+    min_audio_ms: 300
+    start_min_speech_ms: 160
+    pre_speech_ms: 240
+    final_tail_ms: 120
+
+  duplex:
+    enabled: true
+    system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
+
+  barge_in:
+    min_duration_ms: 200
+    silence_tolerance_ms: 60
diff --git a/engine/config/agents/tools.yaml b/engine/config/agents/tools.yaml
new file mode 100644
index 0000000..9734bff
--- /dev/null
+++ b/engine/config/agents/tools.yaml
@@ -0,0 +1,73 @@
+# Agent behavior configuration with tool declarations.
+# This profile is an example only.
+ +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + api_key: your_llm_api_key + api_url: https://api.qnaigc.com/v1 + + tts: + # provider: edge | openai_compatible | siliconflow + provider: openai_compatible + api_key: your_tts_api_key + api_url: https://api.siliconflow.cn/v1/audio/speech + model: FunAudioLLM/CosyVoice2-0.5B + voice: anna + speed: 1.0 + + asr: + # provider: buffered | openai_compatible | siliconflow + provider: openai_compatible + api_key: your_asr_api_key + api_url: https://api.siliconflow.cn/v1/audio/transcriptions + model: FunAudioLLM/SenseVoiceSmall + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: You are a helpful voice assistant with tool-calling support. + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 + + # Tool declarations consumed by the engine at startup. + # - String form enables built-in/default tool schema when available. + # - Object form provides OpenAI function schema + executor hint. + tools: + - current_time + - calculator + - name: weather + description: Get weather by city name. + parameters: + type: object + properties: + city: + type: string + description: City name, for example "San Francisco". + required: [city] + executor: server + - name: open_map + description: Open map app on the client device. + parameters: + type: object + properties: + query: + type: string + required: [query] + executor: client diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 508ba2b..5722cea 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -14,7 +14,8 @@ event-driven design. import asyncio import json import time -from typing import Any, Dict, List, Optional, Tuple +import uuid +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple import numpy as np from loguru import logger @@ -59,6 +60,12 @@ class DuplexPipeline: _MIN_SPLIT_SPOKEN_CHARS = 6 _TOOL_WAIT_TIMEOUT_SECONDS = 15.0 _SERVER_TOOL_TIMEOUT_SECONDS = 15.0 + TRACK_AUDIO_IN = "audio_in" + TRACK_AUDIO_OUT = "audio_out" + TRACK_CONTROL = "control" + _PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms + _ASR_DELTA_THROTTLE_MS = 300 + _LLM_DELTA_THROTTLE_MS = 80 _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { "current_time": { "name": "current_time", @@ -79,7 +86,16 @@ class DuplexPipeline: tts_service: Optional[BaseTTSService] = None, asr_service: Optional[BaseASRService] = None, system_prompt: Optional[str] = None, - greeting: Optional[str] = None + greeting: Optional[str] = None, + knowledge_searcher: Optional[ + Callable[..., Awaitable[List[Dict[str, Any]]]] + ] = None, + tool_resource_resolver: Optional[ + Callable[[str], Awaitable[Optional[Dict[str, Any]]]] + ] = None, + server_tool_executor: Optional[ + Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]] + ] = None, ): """ Initialize duplex pipeline. 
@@ -96,6 +112,9 @@ class DuplexPipeline: self.transport = transport self.session_id = session_id self.event_bus = get_event_bus() + self.track_audio_in = self.TRACK_AUDIO_IN + self.track_audio_out = self.TRACK_AUDIO_OUT + self.track_control = self.TRACK_CONTROL # Initialize VAD self.vad_model = SileroVAD( @@ -117,9 +136,14 @@ class DuplexPipeline: self.llm_service = llm_service self.tts_service = tts_service self.asr_service = asr_service # Will be initialized in start() + self._knowledge_searcher = knowledge_searcher + self._tool_resource_resolver = tool_resource_resolver + self._server_tool_executor = server_tool_executor # Track last sent transcript to avoid duplicates self._last_sent_transcript = "" + self._pending_transcript_delta: str = "" + self._last_transcript_delta_emit_ms: float = 0.0 # Conversation manager self.conversation = ConversationManager( @@ -153,6 +177,7 @@ class DuplexPipeline: self._outbound_seq = 0 self._outbound_task: Optional[asyncio.Task] = None self._drop_outbound_audio = False + self._audio_out_frame_buffer: bytes = b"" # Interruption handling self._interrupt_event = asyncio.Event() @@ -181,14 +206,48 @@ class DuplexPipeline: self._runtime_barge_in_min_duration_ms: Optional[int] = None self._runtime_knowledge: Dict[str, Any] = {} self._runtime_knowledge_base_id: Optional[str] = None - self._runtime_tools: List[Any] = [] + raw_default_tools = settings.tools if isinstance(settings.tools, list) else [] + self._runtime_tools: List[Any] = list(raw_default_tools) self._runtime_tool_executor: Dict[str, str] = {} self._pending_tool_waiters: Dict[str, asyncio.Future] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() + self._pending_client_tool_call_ids: set[str] = set() + self._next_seq: Optional[Callable[[], int]] = None + self._local_seq: int = 0 + + # Cross-service correlation IDs + self._turn_count: int = 0 + self._response_count: int = 0 + self._tts_count: int = 0 + self._utterance_count: int = 0 + self._current_turn_id: Optional[str] = None + self._current_utterance_id: Optional[str] = None + self._current_response_id: Optional[str] = None + self._current_tts_id: Optional[str] = None + self._pending_llm_delta: str = "" + self._last_llm_delta_emit_ms: float = 0.0 + + self._runtime_tool_executor = self._resolved_tool_executor_map() + + if self._server_tool_executor is None: + if self._tool_resource_resolver: + async def _executor(call: Dict[str, Any]) -> Dict[str, Any]: + return await execute_server_tool( + call, + tool_resource_fetcher=self._tool_resource_resolver, + ) + + self._server_tool_executor = _executor + else: + self._server_tool_executor = execute_server_tool logger.info(f"DuplexPipeline initialized for session {session_id}") + def set_event_sequence_provider(self, provider: Callable[[], int]) -> None: + """Use session-scoped monotonic sequence provider for envelope events.""" + self._next_seq = provider + def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None: """ Apply runtime overrides from WS session.start metadata. 
@@ -276,6 +335,136 @@ class DuplexPipeline: if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) + def resolved_runtime_config(self) -> Dict[str, Any]: + """Return current effective runtime configuration without secrets.""" + llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower() + llm_base_url = ( + self._runtime_llm.get("baseUrl") + or settings.llm_api_url + or self._default_llm_base_url(llm_provider) + ) + tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower() + asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower() + output_mode = str(self._runtime_output.get("mode") or "").strip().lower() + if not output_mode: + output_mode = "audio" if self._tts_output_enabled() else "text" + + return { + "output": {"mode": output_mode}, + "services": { + "llm": { + "provider": llm_provider, + "model": str(self._runtime_llm.get("model") or settings.llm_model), + "baseUrl": llm_base_url, + }, + "asr": { + "provider": asr_provider, + "model": str(self._runtime_asr.get("model") or settings.asr_model or ""), + "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), + "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), + }, + "tts": { + "enabled": self._tts_output_enabled(), + "provider": tts_provider, + "model": str(self._runtime_tts.get("model") or settings.tts_model or ""), + "voice": str(self._runtime_tts.get("voice") or settings.tts_voice), + "speed": float(self._runtime_tts.get("speed") or settings.tts_speed), + }, + }, + "tools": { + "allowlist": self._resolved_tool_allowlist(), + }, + "tracks": { + "audio_in": self.track_audio_in, + "audio_out": self.track_audio_out, + "control": self.track_control, + }, + } + + def _next_event_seq(self) -> int: + if self._next_seq: + return self._next_seq() + self._local_seq += 1 + return self._local_seq + + def _event_source(self, event_type: str) -> str: + if event_type.startswith("transcript.") or event_type.startswith("input.speech_"): + return "asr" + if event_type.startswith("assistant.response."): + return "llm" + if event_type.startswith("assistant.tool_"): + return "tool" + if event_type.startswith("output.audio.") or event_type == "metrics.ttfb": + return "tts" + return "system" + + def _new_id(self, prefix: str, counter: int) -> str: + return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}" + + def _start_turn(self) -> str: + self._turn_count += 1 + self._current_turn_id = self._new_id("turn", self._turn_count) + self._current_utterance_id = None + self._current_response_id = None + self._current_tts_id = None + return self._current_turn_id + + def _start_response(self) -> str: + self._response_count += 1 + self._current_response_id = self._new_id("resp", self._response_count) + self._current_tts_id = None + return self._current_response_id + + def _start_tts(self) -> str: + self._tts_count += 1 + self._current_tts_id = self._new_id("tts", self._tts_count) + return self._current_tts_id + + def _finalize_utterance(self) -> str: + if self._current_utterance_id: + return self._current_utterance_id + self._utterance_count += 1 + self._current_utterance_id = self._new_id("utt", self._utterance_count) + if not self._current_turn_id: + self._start_turn() + return self._current_utterance_id + + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: + event_type = str(event.get("type") or "") + 
source = str(event.get("source") or self._event_source(event_type)) + track_id = event.get("trackId") + if not track_id: + if source == "asr": + track_id = self.track_audio_in + elif source in {"llm", "tts", "tool"}: + track_id = self.track_audio_out + else: + track_id = self.track_control + + data = event.get("data") + if not isinstance(data, dict): + data = {} + if self._current_turn_id: + data.setdefault("turn_id", self._current_turn_id) + if self._current_utterance_id: + data.setdefault("utterance_id", self._current_utterance_id) + if self._current_response_id: + data.setdefault("response_id", self._current_response_id) + if self._current_tts_id: + data.setdefault("tts_id", self._current_tts_id) + + for k, v in event.items(): + if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}: + continue + data.setdefault(k, v) + + event["sessionId"] = self.session_id + event["seq"] = self._next_event_seq() + event["source"] = source + event["trackId"] = track_id + event["data"] = data + return event + @staticmethod def _coerce_bool(value: Any) -> Optional[bool]: if isinstance(value, bool): @@ -295,6 +484,18 @@ class DuplexPipeline: normalized = str(provider or "").strip().lower() return normalized in {"openai_compatible", "openai-compatible", "siliconflow"} + @staticmethod + def _is_llm_provider_supported(provider: Any) -> bool: + normalized = str(provider or "").strip().lower() + return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"} + + @staticmethod + def _default_llm_base_url(provider: Any) -> Optional[str]: + normalized = str(provider or "").strip().lower() + if normalized == "siliconflow": + return "https://api.siliconflow.cn/v1" + return None + def _tts_output_enabled(self) -> bool: enabled = self._coerce_bool(self._runtime_tts.get("enabled")) if enabled is not None: @@ -370,20 +571,25 @@ class DuplexPipeline: try: # Connect LLM service if not self.llm_service: - llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key - llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url + llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower() + llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key + llm_base_url = ( + self._runtime_llm.get("baseUrl") + or settings.llm_api_url + or self._default_llm_base_url(llm_provider) + ) llm_model = self._runtime_llm.get("model") or settings.llm_model - llm_provider = (self._runtime_llm.get("provider") or "openai").lower() - if llm_provider == "openai" and llm_api_key: + if self._is_llm_provider_supported(llm_provider) and llm_api_key: self.llm_service = OpenAILLMService( api_key=llm_api_key, base_url=llm_base_url, model=llm_model, knowledge_config=self._resolved_knowledge_config(), + knowledge_searcher=self._knowledge_searcher, ) else: - logger.warning("No OpenAI API key - using mock LLM") + logger.warning("LLM provider unsupported or API key missing - using mock LLM") self.llm_service = MockLLMService() if hasattr(self.llm_service, "set_knowledge_config"): @@ -399,20 +605,22 @@ class DuplexPipeline: if tts_output_enabled: if not self.tts_service: tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() - tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key + tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key + tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url tts_voice = self._runtime_tts.get("voice") or settings.tts_voice - tts_model 
= self._runtime_tts.get("model") or settings.siliconflow_tts_model + tts_model = self._runtime_tts.get("model") or settings.tts_model tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) if self._is_openai_compatible_provider(tts_provider) and tts_api_key: self.tts_service = OpenAICompatibleTTSService( api_key=tts_api_key, + api_url=tts_api_url, voice=tts_voice, - model=tts_model, + model=tts_model or "FunAudioLLM/CosyVoice2-0.5B", sample_rate=settings.sample_rate, speed=tts_speed ) - logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)") + logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})") else: self.tts_service = EdgeTTSService( voice=tts_voice, @@ -435,21 +643,23 @@ class DuplexPipeline: # Connect ASR service if not self.asr_service: asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower() - asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key - asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model + asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key + asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url + asr_model = self._runtime_asr.get("model") or settings.asr_model asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) if self._is_openai_compatible_provider(asr_provider) and asr_api_key: self.asr_service = OpenAICompatibleASRService( api_key=asr_api_key, - model=asr_model, + api_url=asr_api_url, + model=asr_model or "FunAudioLLM/SenseVoiceSmall", sample_rate=settings.sample_rate, interim_interval_ms=asr_interim_interval, min_audio_for_interim_ms=asr_min_audio_ms, on_transcript=self._on_transcript_callback ) - logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)") + logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})") else: self.asr_service = BufferedASRService( sample_rate=settings.sample_rate @@ -472,11 +682,13 @@ class DuplexPipeline: greeting_to_speak = generated_greeting self.conversation.greeting = generated_greeting if greeting_to_speak: + self._start_turn() + self._start_response() await self._send_event( ev( "assistant.response.final", text=greeting_to_speak, - trackId=self.session_id, + trackId=self.track_audio_out, ), priority=20, ) @@ -494,10 +706,58 @@ class DuplexPipeline: await self._outbound_q.put((priority, self._outbound_seq, kind, payload)) async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None: - await self._enqueue_outbound("event", event, priority) + await self._enqueue_outbound("event", self._envelope_event(event), priority) async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None: - await self._enqueue_outbound("audio", pcm_bytes, priority) + if not pcm_bytes: + return + self._audio_out_frame_buffer += pcm_bytes + while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES: + frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES] + self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :] + await self._enqueue_outbound("audio", frame, priority) + + async def _flush_audio_out_frames(self, priority: int = 50) -> None: + """Flush remaining outbound audio as one padded 20ms PCM frame.""" + if not self._audio_out_frame_buffer: + return + tail = self._audio_out_frame_buffer + self._audio_out_frame_buffer = b"" + if 
len(tail) < self._PCM_FRAME_BYTES: + tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail))) + await self._enqueue_outbound("audio", tail, priority) + + async def _emit_transcript_delta(self, text: str) -> None: + await self._send_event( + { + **ev( + "transcript.delta", + trackId=self.track_audio_in, + text=text, + ) + }, + priority=30, + ) + + async def _emit_llm_delta(self, text: str) -> None: + await self._send_event( + { + **ev( + "assistant.response.delta", + trackId=self.track_audio_out, + text=text, + ) + }, + priority=20, + ) + + async def _flush_pending_llm_delta(self) -> None: + if not self._pending_llm_delta: + return + chunk = self._pending_llm_delta + self._pending_llm_delta = "" + self._last_llm_delta_emit_ms = time.monotonic() * 1000.0 + await self._emit_llm_delta(chunk) async def _outbound_loop(self) -> None: """Single sender loop that enforces priority for interrupt events.""" @@ -546,13 +806,13 @@ class DuplexPipeline: # Emit VAD event await self.event_bus.publish(event_type, { - "trackId": self.session_id, + "trackId": self.track_audio_in, "probability": probability }) await self._send_event( ev( "input.speech_started" if event_type == "speaking" else "input.speech_stopped", - trackId=self.session_id, + trackId=self.track_audio_in, probability=probability, ), priority=30, @@ -661,6 +921,9 @@ class DuplexPipeline: # Cancel any current speaking await self._stop_current_speech() + self._start_turn() + self._finalize_utterance() + # Start new turn await self.conversation.end_user_turn(text) self._current_turn_task = asyncio.create_task(self._handle_turn(text)) @@ -683,24 +946,45 @@ class DuplexPipeline: if text == self._last_sent_transcript and not is_final: return + now_ms = time.monotonic() * 1000.0 self._last_sent_transcript = text - # Send transcript event to client - await self._send_event({ - **ev( - "transcript.final" if is_final else "transcript.delta", - trackId=self.session_id, - text=text, + if is_final: + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = 0.0 + await self._send_event( + { + **ev( + "transcript.final", + trackId=self.track_audio_in, + text=text, + ) + }, + priority=30, ) - }, priority=30) + logger.debug(f"Sent transcript (final): {text[:50]}...") + return + + self._pending_transcript_delta = text + should_emit = ( + self._last_transcript_delta_emit_ms <= 0.0 + or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS + ) + if should_emit and self._pending_transcript_delta: + delta = self._pending_transcript_delta + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = now_ms + await self._emit_transcript_delta(delta) if not is_final: logger.info(f"[ASR] ASR interim: {text[:100]}") - logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") + logger.debug(f"Sent transcript (interim): {text[:50]}...") async def _on_speech_start(self) -> None: """Handle user starting to speak.""" if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED): + self._start_turn() + self._finalize_utterance() await self.conversation.start_user_turn() self._audio_buffer = b"" self._last_sent_transcript = "" @@ -779,6 +1063,7 @@ class DuplexPipeline: return logger.info(f"[EOU] Detected - user said: {user_text[:100]}...") + self._finalize_utterance() # For ASR backends that already emitted final via callback, # avoid duplicating transcript.final on EOU. 
@@ -786,7 +1071,7 @@ class DuplexPipeline: await self._send_event({ **ev( "transcript.final", - trackId=self.session_id, + trackId=self.track_audio_in, text=user_text, ) }, priority=25) @@ -794,6 +1079,8 @@ class DuplexPipeline: # Clear buffers self._audio_buffer = b"" self._last_sent_transcript = "" + self._pending_transcript_delta = "" + self._last_transcript_delta_emit_ms = 0.0 self._asr_capture_active = False self._pending_speech_audio = b"" @@ -881,6 +1168,23 @@ class DuplexPipeline: result[name] = executor return result + def _resolved_tool_allowlist(self) -> List[str]: + names: set[str] = set() + for item in self._runtime_tools: + if isinstance(item, str): + name = item.strip() + if name: + names.add(name) + continue + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + names.add(str(fn.get("name")).strip()) + elif item.get("name"): + names.add(str(item.get("name")).strip()) + return sorted([name for name in names if name]) + def _tool_name(self, tool_call: Dict[str, Any]) -> str: fn = tool_call.get("function") if isinstance(fn, dict): @@ -894,6 +1198,44 @@ class DuplexPipeline: # Default to server execution unless explicitly marked as client. return "server" + def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: + fn = tool_call.get("function") + if not isinstance(fn, dict): + return {} + raw = fn.get("arguments") + if isinstance(raw, dict): + return raw + if isinstance(raw, str) and raw.strip(): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else {"raw": raw} + except Exception: + return {"raw": raw} + return {} + + def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + status = result.get("status") if isinstance(result.get("status"), dict) else {} + status_code = int(status.get("code") or 0) if status else 0 + status_message = str(status.get("message") or "") if status else "" + tool_call_id = str(result.get("tool_call_id") or result.get("id") or "") + tool_name = str(result.get("name") or "unknown_tool") + ok = bool(200 <= status_code < 300) + retryable = status_code >= 500 or status_code in {429, 408} + error: Optional[Dict[str, Any]] = None + if not ok: + error = { + "code": status_code or 500, + "message": status_message or "tool_execution_failed", + "retryable": retryable, + } + return { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "ok": ok, + "error": error, + "status": {"code": status_code, "message": status_message}, + } + async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None: tool_name = str(result.get("name") or "unknown_tool") call_id = str(result.get("tool_call_id") or result.get("id") or "") @@ -904,12 +1246,17 @@ class DuplexPipeline: f"[Tool] emit result source={source} name={tool_name} call_id={call_id} " f"status={status_code} {status_message}".strip() ) + normalized = self._normalize_tool_result(result) await self._send_event( { **ev( "assistant.tool_result", - trackId=self.session_id, + trackId=self.track_audio_out, source=source, + tool_call_id=normalized["tool_call_id"], + tool_name=normalized["tool_name"], + ok=normalized["ok"], + error=normalized["error"], result=result, ) }, @@ -927,6 +1274,9 @@ class DuplexPipeline: call_id = str(item.get("tool_call_id") or item.get("id") or "").strip() if not call_id: continue + if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids: + logger.warning(f"[Tool] ignore unsolicited client result 
call_id={call_id}") + continue if call_id in self._completed_tool_call_ids: logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}") continue @@ -972,6 +1322,7 @@ class DuplexPipeline: } finally: self._pending_tool_waiters.pop(call_id, None) + self._pending_client_tool_call_ids.discard(call_id) def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: if isinstance(item, LLMStreamEvent): @@ -998,6 +1349,11 @@ class DuplexPipeline: user_text: User's transcribed text """ try: + if not self._current_turn_id: + self._start_turn() + if not self._current_utterance_id: + self._finalize_utterance() + self._start_response() # Start latency tracking self._turn_start_time = time.time() self._first_audio_sent = False @@ -1012,6 +1368,8 @@ class DuplexPipeline: self._drop_outbound_audio = False first_audio_sent = False + self._pending_llm_delta = "" + self._last_llm_delta_emit_ms = 0.0 for _ in range(max_rounds): if self._interrupt_event.is_set(): break @@ -1028,6 +1386,7 @@ class DuplexPipeline: event = self._normalize_stream_event(raw_event) if event.type == "tool_call": + await self._flush_pending_llm_delta() tool_call = event.tool_call if isinstance(event.tool_call, dict) else None if not tool_call: continue @@ -1045,11 +1404,19 @@ class DuplexPipeline: f"executor={executor} args={args_preview}" ) tool_calls.append(enriched_tool_call) + tool_arguments = self._tool_arguments(enriched_tool_call) + if executor == "client" and call_id: + self._pending_client_tool_call_ids.add(call_id) await self._send_event( { **ev( "assistant.tool_call", - trackId=self.session_id, + trackId=self.track_audio_out, + tool_call_id=call_id, + tool_name=tool_name, + arguments=tool_arguments, + executor=executor, + timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000), tool_call=enriched_tool_call, ) }, @@ -1071,19 +1438,13 @@ class DuplexPipeline: round_response += text_chunk sentence_buffer += text_chunk await self.conversation.update_assistant_text(text_chunk) - - await self._send_event( - { - **ev( - "assistant.response.delta", - trackId=self.session_id, - text=text_chunk, - ) - }, - # Keep delta/final on the same event priority so FIFO seq - # preserves stream order (avoid late-delta after final). 
- priority=20, - ) + self._pending_llm_delta += text_chunk + now_ms = time.monotonic() * 1000.0 + if ( + self._last_llm_delta_emit_ms <= 0.0 + or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS + ): + await self._flush_pending_llm_delta() while True: split_result = extract_tts_sentence( @@ -1112,11 +1473,12 @@ class DuplexPipeline: if self._tts_output_enabled() and not self._interrupt_event.is_set(): if not first_audio_sent: + self._start_tts() await self._send_event( { **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10, @@ -1130,6 +1492,7 @@ class DuplexPipeline: ) remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + await self._flush_pending_llm_delta() if ( self._tts_output_enabled() and remaining_text @@ -1137,11 +1500,12 @@ class DuplexPipeline: and not self._interrupt_event.is_set() ): if not first_audio_sent: + self._start_tts() await self._send_event( { **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10, @@ -1172,7 +1536,7 @@ class DuplexPipeline: try: result = await asyncio.wait_for( - execute_server_tool(call), + self._server_tool_executor(call), timeout=self._SERVER_TOOL_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: @@ -1204,11 +1568,12 @@ class DuplexPipeline: ] if full_response and not self._interrupt_event.is_set(): + await self._flush_pending_llm_delta() await self._send_event( { **ev( "assistant.response.final", - trackId=self.session_id, + trackId=self.track_audio_out, text=full_response, ) }, @@ -1217,10 +1582,11 @@ class DuplexPipeline: # Send track end if first_audio_sent: + await self._flush_audio_out_frames(priority=50) await self._send_event({ **ev( "output.audio.end", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1241,6 +1607,8 @@ class DuplexPipeline: self._barge_in_speech_start_time = None self._barge_in_speech_frames = 0 self._barge_in_silence_frames = 0 + self._current_response_id = None + self._current_tts_id = None async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None: """ @@ -1277,7 +1645,7 @@ class DuplexPipeline: await self._send_event({ **ev( "metrics.ttfb", - trackId=self.session_id, + trackId=self.track_audio_out, latencyMs=round(ttfb_ms), ) }, priority=25) @@ -1354,10 +1722,11 @@ class DuplexPipeline: first_audio_sent = False # Send track start event + self._start_tts() await self._send_event({ **ev( "output.audio.start", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1379,7 +1748,7 @@ class DuplexPipeline: await self._send_event({ **ev( "metrics.ttfb", - trackId=self.session_id, + trackId=self.track_audio_out, latencyMs=round(ttfb_ms), ) }, priority=25) @@ -1391,10 +1760,11 @@ class DuplexPipeline: await asyncio.sleep(0.01) # Send track end event + await self._flush_audio_out_frames(priority=50) await self._send_event({ **ev( "output.audio.end", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=10) @@ -1422,13 +1792,14 @@ class DuplexPipeline: self._interrupt_event.set() self._is_bot_speaking = False self._drop_outbound_audio = True + self._audio_out_frame_buffer = b"" # Send interrupt event to client IMMEDIATELY # This must happen BEFORE canceling services, so client knows to discard in-flight audio await self._send_event({ **ev( "response.interrupted", - trackId=self.session_id, + trackId=self.track_audio_out, ) }, priority=0) @@ -1455,6 +1826,7 @@ class DuplexPipeline: async 
def _stop_current_speech(self) -> None: """Stop any current speech task.""" self._drop_outbound_audio = True + self._audio_out_frame_buffer = b"" if self._current_turn_task and not self._current_turn_task.done(): self._interrupt_event.set() self._current_turn_task.cancel() diff --git a/engine/core/history_bridge.py b/engine/core/history_bridge.py new file mode 100644 index 0000000..ead9a3b --- /dev/null +++ b/engine/core/history_bridge.py @@ -0,0 +1,244 @@ +"""Async history bridge for non-blocking transcript persistence.""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from typing import Any, Optional + +from loguru import logger + + +@dataclass +class _HistoryTranscriptJob: + call_id: str + turn_index: int + speaker: str + content: str + start_ms: int + end_ms: int + duration_ms: int + + +class SessionHistoryBridge: + """Session-scoped buffered history writer with background retries.""" + + _STOP_SENTINEL = object() + + def __init__( + self, + *, + history_writer: Any, + enabled: bool, + queue_max_size: int, + retry_max_attempts: int, + retry_backoff_sec: float, + finalize_drain_timeout_sec: float, + ): + self._history_writer = history_writer + self._enabled = bool(enabled and history_writer is not None) + self._queue_max_size = max(1, int(queue_max_size)) + self._retry_max_attempts = max(0, int(retry_max_attempts)) + self._retry_backoff_sec = max(0.0, float(retry_backoff_sec)) + self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec)) + + self._call_id: Optional[str] = None + self._turn_index: int = 0 + self._started_mono: Optional[float] = None + self._finalized: bool = False + self._worker_task: Optional[asyncio.Task] = None + self._finalize_lock = asyncio.Lock() + self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size) + + @property + def enabled(self) -> bool: + return self._enabled + + @property + def call_id(self) -> Optional[str]: + return self._call_id + + async def start_call( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str, + ) -> Optional[str]: + """Create remote call record and start background worker.""" + if not self._enabled or self._call_id: + return self._call_id + + call_id = await self._history_writer.create_call_record( + user_id=user_id, + assistant_id=assistant_id, + source=source, + ) + if not call_id: + return None + + self._call_id = str(call_id) + self._turn_index = 0 + self._finalized = False + self._started_mono = time.monotonic() + self._ensure_worker() + return self._call_id + + def elapsed_ms(self) -> int: + if self._started_mono is None: + return 0 + return max(0, int((time.monotonic() - self._started_mono) * 1000)) + + def enqueue_turn(self, *, role: str, text: str) -> bool: + """Queue one transcript write without blocking the caller.""" + if not self._enabled or not self._call_id or self._finalized: + return False + + content = str(text or "").strip() + if not content: + return False + + speaker = "human" if str(role or "").strip().lower() == "user" else "ai" + end_ms = self.elapsed_ms() + estimated_duration_ms = max(300, min(12000, len(content) * 80)) + start_ms = max(0, end_ms - estimated_duration_ms) + + job = _HistoryTranscriptJob( + call_id=self._call_id, + turn_index=self._turn_index, + speaker=speaker, + content=content, + start_ms=start_ms, + end_ms=end_ms, + duration_ms=max(1, end_ms - start_ms), + ) + self._turn_index += 1 + self._ensure_worker() + + try: + self._queue.put_nowait(job) + 
return True + except asyncio.QueueFull: + logger.warning( + "History queue full; dropping transcript call_id={} turn={}", + self._call_id, + job.turn_index, + ) + return False + + async def finalize(self, *, status: str) -> bool: + """Finalize history record once; waits briefly for queue drain.""" + if not self._enabled or not self._call_id: + return False + + async with self._finalize_lock: + if self._finalized: + return True + + await self._drain_queue() + ok = await self._history_writer.finalize_call_record( + call_id=self._call_id, + status=status, + duration_seconds=self.duration_seconds(), + ) + if ok: + self._finalized = True + await self._stop_worker() + return ok + + async def shutdown(self) -> None: + """Stop worker task and release queue resources.""" + await self._stop_worker() + + def duration_seconds(self) -> int: + if self._started_mono is None: + return 0 + return max(0, int(time.monotonic() - self._started_mono)) + + def _ensure_worker(self) -> None: + if self._worker_task and not self._worker_task.done(): + return + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def _drain_queue(self) -> None: + if self._finalize_drain_timeout_sec <= 0: + return + try: + await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec) + except asyncio.TimeoutError: + logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec) + + async def _stop_worker(self) -> None: + task = self._worker_task + if not task: + return + if task.done(): + self._worker_task = None + return + + sent = False + try: + self._queue.put_nowait(self._STOP_SENTINEL) + sent = True + except asyncio.QueueFull: + pass + + if not sent: + try: + await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5) + except asyncio.TimeoutError: + task.cancel() + + try: + await asyncio.wait_for(task, timeout=1.5) + except asyncio.TimeoutError: + task.cancel() + try: + await task + except Exception: + pass + except asyncio.CancelledError: + pass + finally: + self._worker_task = None + + async def _worker_loop(self) -> None: + while True: + item = await self._queue.get() + try: + if item is self._STOP_SENTINEL: + return + + assert isinstance(item, _HistoryTranscriptJob) + await self._write_with_retry(item) + except Exception as exc: + logger.warning("History worker write failed unexpectedly: {}", exc) + finally: + self._queue.task_done() + + async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool: + for attempt in range(self._retry_max_attempts + 1): + ok = await self._history_writer.add_transcript( + call_id=job.call_id, + turn_index=job.turn_index, + speaker=job.speaker, + content=job.content, + start_ms=job.start_ms, + end_ms=job.end_ms, + duration_ms=job.duration_ms, + ) + if ok: + return True + + if attempt >= self._retry_max_attempts: + logger.warning( + "History write dropped after retries call_id={} turn={}", + job.call_id, + job.turn_index, + ) + return False + + if self._retry_backoff_sec > 0: + await asyncio.sleep(self._retry_backoff_sec * (2**attempt)) + return False diff --git a/engine/core/ports/__init__.py b/engine/core/ports/__init__.py new file mode 100644 index 0000000..7d7c9dd --- /dev/null +++ b/engine/core/ports/__init__.py @@ -0,0 +1,17 @@ +"""Port interfaces for engine-side integration boundaries.""" + +from core.ports.backend import ( + AssistantConfigProvider, + BackendGateway, + HistoryWriter, + KnowledgeSearcher, + ToolResourceResolver, +) + +__all__ = [ + "AssistantConfigProvider", + "BackendGateway", + 
"HistoryWriter", + "KnowledgeSearcher", + "ToolResourceResolver", +] diff --git a/engine/core/ports/backend.py b/engine/core/ports/backend.py new file mode 100644 index 0000000..227c743 --- /dev/null +++ b/engine/core/ports/backend.py @@ -0,0 +1,84 @@ +"""Backend integration ports. + +These interfaces define the boundary between engine runtime logic and +backend-side capabilities (config lookup, history persistence, retrieval, +and tool resource discovery). +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Protocol + + +class AssistantConfigProvider(Protocol): + """Port for loading trusted assistant runtime configuration.""" + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + """Fetch assistant configuration payload.""" + + +class HistoryWriter(Protocol): + """Port for persisting call and transcript history.""" + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + """Create a call record and return backend call ID.""" + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + """Append one transcript turn segment.""" + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + """Finalize a call record.""" + + +class KnowledgeSearcher(Protocol): + """Port for RAG / knowledge retrieval operations.""" + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + """Search a knowledge source and return ranked snippets.""" + + +class ToolResourceResolver(Protocol): + """Port for resolving tool metadata/configuration.""" + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch tool resource configuration.""" + + +class BackendGateway( + AssistantConfigProvider, + HistoryWriter, + KnowledgeSearcher, + ToolResourceResolver, + Protocol, +): + """Composite backend gateway interface used by engine services.""" + diff --git a/engine/core/session.py b/engine/core/session.py index 3f8f18d..4f19eba 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -1,22 +1,19 @@ """Session management for active calls.""" import asyncio -import uuid +import hashlib import json -import time import re +import time from enum import Enum from typing import Optional, Dict, Any, List from loguru import logger -from app.backend_client import ( - create_history_call_record, - add_history_transcript, - finalize_history_call_record, -) +from app.backend_adapters import build_backend_adapter_from_settings from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline from core.conversation import ConversationTurn +from core.history_bridge import SessionHistoryBridge from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef from app.config import settings from services.base import LLMMessage @@ -49,7 +46,39 @@ class Session: Uses full duplex voice conversation pipeline. 
""" - def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): + TRACK_AUDIO_IN = "audio_in" + TRACK_AUDIO_OUT = "audio_out" + TRACK_CONTROL = "control" + AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms + _CLIENT_METADATA_OVERRIDES = { + "firstTurnMode", + "greeting", + "generatedOpenerEnabled", + "systemPrompt", + "output", + "bargeIn", + "knowledge", + "knowledgeBaseId", + "history", + "userId", + "assistantId", + "source", + } + _CLIENT_METADATA_ID_KEYS = { + "appId", + "app_id", + "channel", + "configVersionId", + "config_version_id", + } + + def __init__( + self, + session_id: str, + transport: BaseTransport, + use_duplex: bool = None, + backend_gateway: Optional[Any] = None, + ): """ Initialize session. @@ -61,12 +90,23 @@ class Session: self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled + self._backend_gateway = backend_gateway or build_backend_adapter_from_settings() + self._history_bridge = SessionHistoryBridge( + history_writer=self._backend_gateway, + enabled=settings.history_enabled, + queue_max_size=settings.history_queue_max_size, + retry_max_attempts=settings.history_retry_max_attempts, + retry_backoff_sec=settings.history_retry_backoff_sec, + finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec, + ) self.pipeline = DuplexPipeline( transport=transport, session_id=session_id, system_prompt=settings.duplex_system_prompt, - greeting=settings.duplex_greeting + greeting=settings.duplex_greeting, + knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None), + tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None), ) # Session state @@ -78,17 +118,15 @@ class Session: self.authenticated: bool = False # Track IDs - self.current_track_id: Optional[str] = str(uuid.uuid4()) - self._history_call_id: Optional[str] = None - self._history_turn_index: int = 0 - self._history_call_started_mono: Optional[float] = None - self._history_finalized: bool = False + self.current_track_id: str = self.TRACK_CONTROL + self._event_seq: int = 0 self._cleanup_lock = asyncio.Lock() self._cleaned_up = False self.workflow_runner: Optional[WorkflowRunner] = None self._workflow_last_user_text: str = "" self._workflow_initial_node: Optional[WorkflowNodeDef] = None + self.pipeline.set_event_sequence_provider(self._next_event_seq) self.pipeline.conversation.on_turn_complete(self._on_turn_complete) logger.info(f"Session {self.id} created (duplex={self.use_duplex})") @@ -129,13 +167,47 @@ class Session: "client", "Audio received before session.start", "protocol.order", + stage="protocol", + retryable=False, ) return try: - await self.pipeline.process_audio(audio_bytes) + if not audio_bytes: + return + if len(audio_bytes) % 2 != 0: + await self._send_error( + "client", + "Invalid PCM payload: odd number of bytes", + "audio.invalid_pcm", + stage="audio", + retryable=False, + ) + return + + frame_bytes = self.AUDIO_FRAME_BYTES + if len(audio_bytes) % frame_bytes != 0: + await self._send_error( + "client", + f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)", + "audio.frame_size_mismatch", + stage="audio", + retryable=False, + ) + return + + for i in range(0, len(audio_bytes), frame_bytes): + frame = audio_bytes[i : i + frame_bytes] + await self.pipeline.process_audio(frame) except Exception as e: logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True) + await self._send_error( + "server", + 
f"Audio processing failed: {e}", + "audio.processing_failed", + stage="audio", + retryable=True, + ) async def _handle_v1_message(self, message: Any) -> None: """Route validated WS v1 message to handlers.""" @@ -176,7 +248,7 @@ class Session: else: await self.pipeline.interrupt() elif isinstance(message, ToolCallResultsMessage): - await self.pipeline.handle_tool_call_results(message.results) + await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results]) elif isinstance(message, SessionStopMessage): await self._handle_session_stop(message.reason) else: @@ -198,9 +270,9 @@ class Session: self.ws_state = WsSessionState.STOPPED return - auth_payload = message.auth or {} - api_key = auth_payload.get("apiKey") - jwt = auth_payload.get("jwt") + auth_payload = message.auth + api_key = auth_payload.apiKey if auth_payload else None + jwt = auth_payload.jwt if auth_payload else None if settings.ws_api_key: if api_key != settings.ws_api_key: @@ -217,10 +289,9 @@ class Session: self.authenticated = True self.protocol_version = message.version self.ws_state = WsSessionState.WAIT_START - await self.transport.send_event( + await self._send_event( ev( "hello.ack", - sessionId=self.id, version=self.protocol_version, ) ) @@ -231,8 +302,12 @@ class Session: await self._send_error("client", "Duplicate session.start", "protocol.order") return - metadata = message.metadata or {} - metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) + raw_metadata = message.metadata or {} + workflow_runtime = self._bootstrap_workflow(raw_metadata) + server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime) + client_runtime = self._sanitize_client_metadata(raw_metadata) + metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime)) + metadata = self._merge_runtime_metadata(metadata, client_runtime) # Create history call record early so later turn callbacks can append transcripts. 
await self._start_history_bridge(metadata) @@ -248,28 +323,37 @@ class Session: self.state = "accepted" self.ws_state = WsSessionState.ACTIVE - await self.transport.send_event( + await self._send_event( ev( "session.started", - sessionId=self.id, trackId=self.current_track_id, - audio=message.audio or {}, + tracks={ + "audio_in": self.TRACK_AUDIO_IN, + "audio_out": self.TRACK_AUDIO_OUT, + "control": self.TRACK_CONTROL, + }, + audio=message.audio.model_dump() if message.audio else {}, + ) + ) + await self._send_event( + ev( + "config.resolved", + trackId=self.TRACK_CONTROL, + config=self._build_config_resolved(metadata), ) ) if self.workflow_runner and self._workflow_initial_node: - await self.transport.send_event( + await self._send_event( ev( "workflow.started", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, workflowName=self.workflow_runner.name, nodeId=self._workflow_initial_node.id, ) ) - await self.transport.send_event( + await self._send_event( ev( "workflow.node.entered", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=self._workflow_initial_node.id, nodeName=self._workflow_initial_node.name, @@ -285,17 +369,24 @@ class Session: stop_reason = reason or "client_requested" self.state = "hungup" self.ws_state = WsSessionState.STOPPED - await self.transport.send_event( + await self._send_event( ev( "session.stopped", - sessionId=self.id, reason=stop_reason, ) ) await self._finalize_history(status="connected") await self.transport.close() - async def _send_error(self, sender: str, error_message: str, code: str) -> None: + async def _send_error( + self, + sender: str, + error_message: str, + code: str, + stage: Optional[str] = None, + retryable: Optional[bool] = None, + track_id: Optional[str] = None, + ) -> None: """ Send error event to client. 
@@ -304,13 +395,26 @@ class Session: error_message: Error message code: Machine-readable error code """ - await self.transport.send_event( + resolved_stage = stage or self._infer_error_stage(code) + resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"}) + resolved_track_id = track_id or self._error_track_id(resolved_stage, code) + await self._send_event( ev( "error", sender=sender, code=code, message=error_message, - trackId=self.current_track_id, + stage=resolved_stage, + retryable=resolved_retryable, + trackId=resolved_track_id, + data={ + "error": { + "stage": resolved_stage, + "code": code, + "message": error_message, + "retryable": resolved_retryable, + } + }, ) ) @@ -329,11 +433,12 @@ class Session: logger.info(f"Session {self.id} cleaning up") await self._finalize_history(status="connected") await self.pipeline.cleanup() + await self._history_bridge.shutdown() await self.transport.close() async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None: """Initialize backend history call record for this session.""" - if self._history_call_id: + if self._history_bridge.call_id: return history_meta: Dict[str, Any] = {} @@ -349,7 +454,7 @@ class Session: assistant_id = history_meta.get("assistantId", metadata.get("assistantId")) source = str(history_meta.get("source", metadata.get("source", "debug"))) - call_id = await create_history_call_record( + call_id = await self._history_bridge.start_call( user_id=user_id, assistant_id=str(assistant_id) if assistant_id else None, source=source, @@ -357,10 +462,6 @@ class Session: if not call_id: return - self._history_call_id = call_id - self._history_call_started_mono = time.monotonic() - self._history_turn_index = 0 - self._history_finalized = False logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})") async def _on_turn_complete(self, turn: ConversationTurn) -> None: @@ -372,48 +473,11 @@ class Session: elif role == "assistant": await self._maybe_advance_workflow(turn.text.strip()) - if not self._history_call_id: - return - if not turn.text or not turn.text.strip(): - return - - role = (turn.role or "").lower() - speaker = "human" if role == "user" else "ai" - - end_ms = 0 - if self._history_call_started_mono is not None: - end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000)) - estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80)) - start_ms = max(0, end_ms - estimated_duration_ms) - - turn_index = self._history_turn_index - await add_history_transcript( - call_id=self._history_call_id, - turn_index=turn_index, - speaker=speaker, - content=turn.text.strip(), - start_ms=start_ms, - end_ms=end_ms, - duration_ms=max(1, end_ms - start_ms), - ) - self._history_turn_index += 1 + self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "") async def _finalize_history(self, status: str) -> None: """Finalize history call record once.""" - if not self._history_call_id or self._history_finalized: - return - - duration_seconds = 0 - if self._history_call_started_mono is not None: - duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono)) - - ok = await finalize_history_call_record( - call_id=self._history_call_id, - status=status, - duration_seconds=duration_seconds, - ) - if ok: - self._history_finalized = True + await self._history_bridge.finalize(status=status) def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Parse 
workflow payload and return initial runtime overrides.""" @@ -483,10 +547,9 @@ class Session: node = transition.node edge = transition.edge - await self.transport.send_event( + await self._send_event( ev( "workflow.edge.taken", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, edgeId=edge.id, fromNodeId=edge.from_node_id, @@ -494,10 +557,9 @@ class Session: reason=reason, ) ) - await self.transport.send_event( + await self._send_event( ev( "workflow.node.entered", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, nodeName=node.name, @@ -510,10 +572,9 @@ class Session: self.pipeline.apply_runtime_overrides(node_runtime) if node.node_type == "tool": - await self.transport.send_event( + await self._send_event( ev( "workflow.tool.requested", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, tool=node.tool or {}, @@ -522,10 +583,9 @@ class Session: return if node.node_type == "human_transfer": - await self.transport.send_event( + await self._send_event( ev( "workflow.human_transfer", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) @@ -534,16 +594,77 @@ class Session: return if node.node_type == "end": - await self.transport.send_event( + await self._send_event( ev( "workflow.ended", - sessionId=self.id, workflowId=self.workflow_runner.workflow_id, nodeId=node.id, ) ) await self._handle_session_stop("workflow_end") + def _next_event_seq(self) -> int: + self._event_seq += 1 + return self._event_seq + + def _event_source(self, event_type: str) -> str: + if event_type.startswith("workflow."): + return "system" + if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat": + return "system" + if event_type == "error": + return "system" + return "system" + + def _infer_error_stage(self, code: str) -> str: + normalized = str(code or "").strip().lower() + if normalized.startswith("audio."): + return "audio" + if normalized.startswith("tool."): + return "tool" + if normalized.startswith("asr."): + return "asr" + if normalized.startswith("llm."): + return "llm" + if normalized.startswith("tts."): + return "tts" + return "protocol" + + def _error_track_id(self, stage: str, code: str) -> str: + if stage in {"audio", "asr"}: + return self.TRACK_AUDIO_IN + if stage in {"llm", "tts", "tool"}: + return self.TRACK_AUDIO_OUT + if str(code or "").strip().lower().startswith("auth."): + return self.TRACK_CONTROL + return self.TRACK_CONTROL + + def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]: + event_type = str(event.get("type") or "") + source = str(event.get("source") or self._event_source(event_type)) + track_id = event.get("trackId") or self.TRACK_CONTROL + + data = event.get("data") + if not isinstance(data, dict): + data = {} + for k, v in event.items(): + if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}: + continue + data.setdefault(k, v) + + event["sessionId"] = self.id + event["seq"] = self._next_event_seq() + event["source"] = source + event["trackId"] = track_id + event["data"] = data + return event + + async def _send_event(self, event: Dict[str, Any]) -> None: + await self.transport.send_event(self._envelope_event(event)) + + async def send_heartbeat(self) -> None: + await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL)) + async def _workflow_llm_route( self, node: WorkflowNodeDef, @@ -629,6 +750,137 @@ class Session: merged[key] = value return merged + async def 
_load_server_runtime_metadata( + self, + client_metadata: Dict[str, Any], + workflow_runtime: Dict[str, Any], + ) -> Dict[str, Any]: + """Load trusted runtime metadata from backend assistant config.""" + assistant_id = ( + workflow_runtime.get("assistantId") + or client_metadata.get("assistantId") + or client_metadata.get("appId") + or client_metadata.get("app_id") + ) + if assistant_id is None: + return {} + + provider = getattr(self._backend_gateway, "fetch_assistant_config", None) + if not callable(provider): + return {} + + payload = await provider(str(assistant_id).strip()) + if not isinstance(payload, dict): + return {} + + assistant_cfg: Dict[str, Any] = {} + session_start_cfg = payload.get("sessionStartMetadata") + if isinstance(session_start_cfg, dict): + assistant_cfg.update(session_start_cfg) + if isinstance(payload.get("assistant"), dict): + assistant_cfg.update(payload.get("assistant")) + elif not assistant_cfg: + assistant_cfg = payload + + if not isinstance(assistant_cfg, dict): + return {} + + runtime: Dict[str, Any] = {} + passthrough_keys = { + "firstTurnMode", + "generatedOpenerEnabled", + "output", + "bargeIn", + "knowledgeBaseId", + "knowledge", + "history", + "userId", + "source", + "tools", + "services", + "configVersionId", + "config_version_id", + } + for key in passthrough_keys: + if key in assistant_cfg: + runtime[key] = assistant_cfg[key] + + if assistant_cfg.get("systemPrompt") is not None: + runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "") + elif assistant_cfg.get("prompt") is not None: + runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "") + + if assistant_cfg.get("greeting") is not None: + runtime["greeting"] = assistant_cfg.get("greeting") + elif assistant_cfg.get("opener") is not None: + runtime["greeting"] = assistant_cfg.get("opener") + + resolved_assistant_id = ( + assistant_cfg.get("assistantId") + or payload.get("assistantId") + or assistant_id + ) + runtime["assistantId"] = str(resolved_assistant_id) + + if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None: + runtime["configVersionId"] = payload.get("configVersionId") + if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None: + runtime["configVersionId"] = payload.get("config_version_id") + + if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None: + runtime["configVersionId"] = runtime.get("config_version_id") + return runtime + + def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitize untrusted metadata sources. + + This keeps only a small override whitelist and stable config ID fields. 
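+
+        Override keys are whitelisted via _CLIENT_METADATA_OVERRIDES; stable
+        ID keys (appId, app_id, channel, configVersionId, config_version_id)
+        via _CLIENT_METADATA_ID_KEYS. All other keys are dropped.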
+ """ + if not isinstance(metadata, dict): + return {} + + sanitized: Dict[str, Any] = {} + for key in self._CLIENT_METADATA_ID_KEYS: + if key in metadata: + sanitized[key] = metadata[key] + for key in self._CLIENT_METADATA_OVERRIDES: + if key in metadata: + sanitized[key] = metadata[key] + + return sanitized + + def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Apply client metadata whitelist and remove forbidden secrets.""" + sanitized = self._sanitize_untrusted_runtime_metadata(metadata) + if isinstance(metadata.get("services"), dict): + logger.warning( + "Session {} provided metadata.services from client; client-side service config is ignored", + self.id, + ) + return sanitized + + def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Build public resolved config payload (secrets removed).""" + system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "") + prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None + runtime = self.pipeline.resolved_runtime_config() + + return { + "appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"), + "channel": metadata.get("channel"), + "configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"), + "prompt": {"sha256": prompt_hash}, + "output": runtime.get("output", {}), + "services": runtime.get("services", {}), + "tools": runtime.get("tools", {}), + "tracks": { + "audio_in": self.TRACK_AUDIO_IN, + "audio_out": self.TRACK_AUDIO_OUT, + "control": self.TRACK_CONTROL, + }, + } + def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: """Best-effort extraction of a JSON object from freeform text.""" try: diff --git a/engine/core/tool_executor.py b/engine/core/tool_executor.py index 407e199..4505436 100644 --- a/engine/core/tool_executor.py +++ b/engine/core/tool_executor.py @@ -4,11 +4,13 @@ import asyncio import ast import operator from datetime import datetime -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict, Optional import aiohttp -from app.backend_client import fetch_tool_resource +from app.backend_adapters import build_backend_adapter_from_settings + +ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]] _BIN_OPS = { ast.Add: operator.add, @@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]: return {} -async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: +async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]: + """Default tool resource resolver via backend adapter.""" + adapter = build_backend_adapter_from_settings() + return await adapter.fetch_tool_resource(tool_id) + + +async def execute_server_tool( + tool_call: Dict[str, Any], + tool_resource_fetcher: Optional[ToolResourceFetcher] = None, +) -> Dict[str, Any]: """Execute a server-side tool and return normalized result payload.""" call_id = str(tool_call.get("id") or "").strip() tool_name = _extract_tool_name(tool_call) args = _extract_tool_args(tool_call) + resource_fetcher = tool_resource_fetcher or fetch_tool_resource if tool_name == "calculator": expression = str(args.get("expression") or "").strip() @@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: } if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: - resource = await fetch_tool_resource(tool_name) + resource = 
await resource_fetcher(tool_name) if resource and str(resource.get("category") or "") == "query": method = str(resource.get("http_method") or "GET").strip().upper() if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: diff --git a/engine/docs/backend_integration.md b/engine/docs/backend_integration.md new file mode 100644 index 0000000..1f5d14d --- /dev/null +++ b/engine/docs/backend_integration.md @@ -0,0 +1,47 @@ +# Backend Integration and History Bridge + +This engine uses adapter-based backend integration so core runtime logic can run +with or without an external backend service. + +## Runtime Modes + +Configure with environment variables: + +- `BACKEND_MODE=auto|http|disabled` +- `BACKEND_URL` +- `BACKEND_TIMEOUT_SEC` +- `HISTORY_ENABLED=true|false` + +Mode behavior: + +- `auto`: use HTTP backend adapter only when `BACKEND_URL` is set. +- `http`: force HTTP backend adapter (falls back to null adapter when URL is missing). +- `disabled`: force null adapter and run engine-only. + +## Architecture + +- Ports: `core/ports/backend.py` +- Adapters: `app/backend_adapters.py` +- Compatibility wrappers: `app/backend_client.py` + +`Session` and `DuplexPipeline` receive backend capabilities via injected adapter +methods instead of hard-coding backend client imports. + +## Async History Writes + +Session history persistence is handled by `core/history_bridge.py`. + +Design: + +- transcript writes are queued with `put_nowait` (non-blocking turn path) +- background worker drains queue +- failed writes retry with exponential backoff +- finalize waits briefly for queue drain before sending call finalize +- finalize is idempotent + +Related settings: + +- `HISTORY_QUEUE_MAX_SIZE` +- `HISTORY_RETRY_MAX_ATTEMPTS` +- `HISTORY_RETRY_BACKOFF_SEC` +- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC` diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index 9db0900..1d35e98 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -2,6 +2,11 @@ This document defines the public WebSocket protocol for the `/ws` endpoint. +Validation policy: +- WS v1 JSON control messages are validated strictly. +- Unknown top-level fields are rejected for all defined client message types. +- `hello.version` is fixed to `"v1"`. + ## Transport - A single WebSocket connection carries: @@ -52,43 +57,26 @@ Rules: "channels": 1 }, "metadata": { + "appId": "assistant_123", + "channel": "web", + "configVersionId": "cfg_20260217_01", "client": "web-debug", "output": { "mode": "audio" }, "systemPrompt": "You are concise.", - "greeting": "Hi, how can I help?", - "services": { - "llm": { - "provider": "openai", - "model": "gpt-4o-mini", - "apiKey": "sk-...", - "baseUrl": "https://api.openai.com/v1" - }, - "asr": { - "provider": "openai_compatible", - "model": "FunAudioLLM/SenseVoiceSmall", - "apiKey": "sf-...", - "interimIntervalMs": 500, - "minAudioMs": 300 - }, - "tts": { - "enabled": true, - "provider": "openai_compatible", - "model": "FunAudioLLM/CosyVoice2-0.5B", - "apiKey": "sf-...", - "voice": "anna", - "speed": 1.0 - } - } + "greeting": "Hi, how can I help?" } } ``` -`metadata.services` is optional. If omitted, server defaults to environment configuration. +Rules: +- Client-side `metadata.services` is ignored. +- Service config (including secrets) is resolved server-side (env/backend). +- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints). 
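+
+A minimal handshake sketch following these rules (assumes a local engine endpoint and the third-party `websockets` package; the IDs below are illustrative):
+
+```python
+import asyncio
+import json
+
+import websockets  # third-party client library
+
+
+async def start_session(url: str = "ws://localhost:8000/ws") -> None:
+    async with websockets.connect(url) as ws:
+        await ws.send(json.dumps({"type": "hello", "version": "v1"}))
+        await ws.send(json.dumps({
+            "type": "session.start",
+            "audio": {"encoding": "pcm_s16le", "sample_rate_hz": 16000, "channels": 1},
+            "metadata": {
+                "appId": "assistant_123",             # stable ID, resolved server-side
+                "channel": "web",
+                "configVersionId": "cfg_20260217_01",
+                "output": {"mode": "audio"},          # small runtime override
+                # no "services" here: the server ignores client-side service config
+            },
+        }))
+        print(json.loads(await ws.recv()))  # expect hello.ack first
+
+
+asyncio.run(start_session())
+```
+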
Text-only mode: -- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. +- Set `metadata.output.mode = "text"`. - In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. ### `input.text` @@ -121,6 +109,7 @@ Text-only mode: ### `tool_call.results` Client tool execution results returned to server. +Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side). ```json { @@ -138,21 +127,36 @@ Client tool execution results returned to server. ## Server -> Client Events -All server events include: +All server events include an envelope: ```json { "type": "event.name", - "timestamp": 1730000000000 + "timestamp": 1730000000000, + "sessionId": "sess_xxx", + "seq": 42, + "source": "asr", + "trackId": "audio_in", + "data": {} } ``` +Envelope notes: +- `seq` is monotonically increasing within one session (for replay/resume). +- `source` is one of: `asr | llm | tts | tool | system | client | server`. + - For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side. +- `data` is structured payload; legacy top-level fields are kept for compatibility. + Common events: - `hello.ack` - Fields: `sessionId`, `version` - `session.started` - - Fields: `sessionId`, `trackId`, `audio` + - Fields: `sessionId`, `trackId`, `tracks`, `audio` +- `config.resolved` + - Fields: `sessionId`, `trackId`, `config` + - Sent immediately after `session.started`. + - Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets. - `session.stopped` - Fields: `sessionId`, `reason` - `heartbeat` @@ -169,9 +173,10 @@ Common events: - `assistant.response.final` - Fields: `trackId`, `text` - `assistant.tool_call` - - Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`) + - Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms` - `assistant.tool_result` - - Fields: `trackId`, `source`, `result` + - Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error` + - `error`: `{ code, message, retryable }` when `ok=false` - `output.audio.start` - Fields: `trackId` - `output.audio.end` @@ -182,16 +187,54 @@ Common events: - Fields: `trackId`, `latencyMs` - `error` - Fields: `sender`, `code`, `message`, `trackId` + - `trackId` convention: + - `audio_in` for `stage in {audio, asr}` + - `audio_out` for `stage in {llm, tts, tool}` + - `control` otherwise (including protocol/auth errors) + +Track IDs (MVP fixed values): +- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`) +- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`) +- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`) + +Correlation IDs (`event.data`): +- `turn_id`: one user-assistant interaction turn. +- `utterance_id`: one ASR final utterance. +- `response_id`: one assistant response generation. +- `tool_call_id`: one tool invocation. +- `tts_id`: one TTS playback segment. ## Binary Audio Frames After `session.started`, client may send binary PCM chunks continuously. -Recommended format: -- 16-bit signed little-endian PCM. -- 1 channel. -- 16000 Hz. -- 20ms frames (640 bytes) preferred. +MVP fixed format: +- 16-bit signed little-endian PCM (`pcm_s16le`) +- mono (1 channel) +- 16000 Hz +- 20ms frame = 640 bytes + +Framing rules: +- Binary audio frame unit is 640 bytes. 
+- A WS binary message may carry one or multiple complete 640-byte frames. +- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly). + +TTS boundary events: +- `output.audio.start` and `output.audio.end` mark assistant playback boundaries. + +## Event Throttling + +To keep client rendering and server load stable, v1 applies/recommends: +- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms). +- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms). +- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms. + +## Error Structure + +`error` keeps legacy top-level fields (`code`, `message`) and adds structured info: +- `stage`: `protocol | asr | llm | tts | tool | audio` +- `retryable`: boolean +- `data.error`: `{ stage, code, message, retryable }` ## Compatibility diff --git a/engine/docs/ws_v1_schema_zh.md b/engine/docs/ws_v1_schema_zh.md new file mode 100644 index 0000000..25c5ad9 --- /dev/null +++ b/engine/docs/ws_v1_schema_zh.md @@ -0,0 +1,520 @@ +# WS v1 协议完整说明(中文) + +本文档描述 `/ws` 端点的 WebSocket v1 协议,覆盖: +- 客户端输入(JSON 文本消息 + 二进制音频); +- 服务端输出(JSON 事件 + 二进制音频); +- 每个参数的类型、约束、含义与使用方式; +- 握手顺序、状态机、错误语义与实现细节。 + +实现对照来源: +- `models/ws_v1.py` +- `core/session.py` +- `core/duplex_pipeline.py` +- `app/main.py` + +--- + +## 1. 传输与基础规则 + +- 连接地址:`ws:///ws` +- 单连接双通道承载: + - 文本帧:JSON 控制消息(严格校验 schema) + - 二进制帧:原始 PCM 音频 +- JSON 校验策略: + - 所有已定义客户端消息都 `extra="forbid"`,即不允许未声明字段; + - `hello.version` 固定必须是 `"v1"`; + - 缺失 `type` 或未知 `type` 会返回协议错误。 + +--- + +## 2. 状态机与消息顺序 + +### 2.1 服务端状态 + +- `WAIT_HELLO`:等待 `hello` +- `WAIT_START`:已通过握手,等待 `session.start` +- `ACTIVE`:会话运行中,可收发文本/音频 +- `STOPPED`:会话结束 + +### 2.2 正确顺序 + +1. 客户端发送 `hello` +2. 服务端返回 `hello.ack` +3. 客户端发送 `session.start` +4. 服务端返回 `session.started` +5. 客户端可持续发送: + - 二进制音频 + - `input.text`(可选) + - `response.cancel`(可选) + - `tool_call.results`(可选) +6. 客户端发送 `session.stop` 或直接断开连接 + +顺序错误会返回 `error`,`code = "protocol.order"`。 + +--- + +## 3. 客户端 -> 服务端消息(输入) + +## 3.1 `hello` + +示例: + +```json +{ + "type": "hello", + "version": "v1", + "auth": { + "apiKey": "optional-api-key", + "jwt": "optional-jwt" + } +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 | +| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 | +| `auth` | object \| null | 否 | 仅允许 `apiKey`、`jwt` | 认证载荷 | 认证策略由服务端配置决定 | +| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 | +| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 | + +认证行为: +- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。 +- 若 `WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY`:`auth.apiKey` 或 `auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。 + +## 3.2 `session.start` + +示例: + +```json +{ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": 16000, + "channels": 1 + }, + "metadata": { + "appId": "assistant_123", + "channel": "web", + "configVersionId": "cfg_20260217_01", + "client": "web-debug", + "output": { + "mode": "audio" + }, + "systemPrompt": "你是简洁助手", + "greeting": "你好,我能帮你什么?" 
+ } +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"session.start"` | 启动会话 | 握手后第二阶段消息 | +| `audio` | object \| null | 否 | 仅支持固定值 | 音频格式描述 | 仅用于声明;MVP 实际只接受固定 PCM | +| `audio.encoding` | string | 否 | 固定 `"pcm_s16le"` | 编码格式 | 非该值会在模型校验层报错 | +| `audio.sample_rate_hz` | number | 否 | 固定 `16000` | 采样率 | 16kHz | +| `audio.channels` | number | 否 | 固定 `1` | 声道数 | 单声道 | +| `metadata` | object \| null | 否 | 任意对象(会被白名单过滤) | 运行时配置 | 用于 app/channel/提示词/输出模式等覆盖 | + +`metadata` 白名单策略(关键): +- 允许透传的标识字段(ID 类): + - `appId` / `app_id` + - `channel` + - `configVersionId` / `config_version_id` +- 允许透传的覆盖字段: + - `firstTurnMode` + - `greeting` + - `generatedOpenerEnabled` + - `systemPrompt` + - `output` + - `bargeIn` + - `knowledge` + - `knowledgeBaseId` + - `history` + - `userId` + - `assistantId` + - `source` +- 客户端传入 `metadata.services` 会被忽略(服务端会记录 warning),服务配置由后端/环境变量决定。 + +`output.mode` 用法: +- `"audio"`(默认语音输出) +- `"text"`(纯文本输出) + - 纯文本模式下仍会收到 `assistant.response.delta/final`; + - 不会收到 TTS 音频帧与 `output.audio.start/end`。 + +## 3.3 `input.text` + +示例: + +```json +{ + "type": "input.text", + "text": "你能做什么?" +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"input.text"` | 文本输入 | 跳过 ASR,直接触发 LLM 回答 | +| `text` | string | 是 | 非空字符串为佳 | 用户文本 | 用于文本聊天或调试 | + +## 3.4 `response.cancel` + +示例: + +```json +{ + "type": "response.cancel", + "graceful": false +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 默认值 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 | +| `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 | + +## 3.5 `tool_call.results` + +仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。 + +示例: + +```json +{ + "type": "tool_call.results", + "results": [ + { + "tool_call_id": "call_abc123", + "name": "weather", + "output": { "temp_c": 21, "condition": "sunny" }, + "status": { "code": 200, "message": "ok" } + } + ] +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"tool_call.results"` | 工具执行回传 | 客户端工具结果上送 | +| `results` | array | 否 | 默认为空数组 | 多个工具结果 | 可批量回传 | +| `results[].tool_call_id` | string | 是 | 任意字符串 | 工具调用ID | 必须与 `assistant.tool_call.tool_call_id` 对应 | +| `results[].name` | string | 是 | 任意字符串 | 工具名 | 建议与请求一致 | +| `results[].output` | any | 否 | 任意 JSON | 工具输出 | 供模型后续组织回答 | +| `results[].status` | object | 是 | 包含 `code`、`message` | 执行状态 | 用于判定成功/失败 | +| `results[].status.code` | number | 是 | HTTP 风格状态码 | 状态码 | `200-299` 判定成功 | +| `results[].status.message` | string | 是 | 任意字符串 | 状态描述 | 例如 `"ok"` / `"timeout"` | + +处理规则: +- 未请求过的 `tool_call_id` 会被忽略(防止伪造/串话); +- 重复回传会被忽略; +- 超时未回传会由服务端合成超时结果(`504`)。 + +## 3.6 `session.stop` + +示例: + +```json +{ + "type": "session.stop", + "reason": "client_disconnect" +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"session.stop"` | 结束会话 | 正常结束推荐发送 | +| `reason` | string \| null | 否 | 任意字符串 | 结束原因 | 服务端会回传到 `session.stopped.reason` | + +--- + +## 4. 二进制音频输入(客户端 -> 服务端) + +在 `session.started` 之后可持续发送二进制音频。 + +固定格式(MVP): +- 编码:`pcm_s16le` +- 采样率:`16000` +- 声道:`1` +- 帧长:20ms = `640 bytes` + +分包规则: +- 单个 WebSocket 二进制消息可包含 1 帧或多帧; +- 长度必须是 `640` 的整数倍; +- 不是 `640` 倍数会触发 `audio.frame_size_mismatch`,该消息整包丢弃; +- 奇数字节长度会触发 `audio.invalid_pcm`。 + +--- + +## 5. 
服务端 -> 客户端事件(输出) + +所有 JSON 事件都包含统一包络字段。 + +## 5.1 统一包络(Envelope) + +```json +{ + "type": "event.name", + "timestamp": 1730000000000, + "sessionId": "sess_xxx", + "seq": 42, + "source": "asr", + "trackId": "audio_in", + "data": {} +} +``` + +字段说明: + +| 字段 | 类型 | 含义 | 使用说明 | +|---|---|---|---| +| `type` | string | 事件类型 | 见下方事件清单 | +| `timestamp` | number | 事件时间戳(毫秒) | 由 `ev()` 生成 | +| `sessionId` | string | 会话ID | 同一连接固定 | +| `seq` | number | 单会话递增序号 | 可用于重放、去重、排序 | +| `source` | string | 事件来源 | 常见:`asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` | +| `trackId` | string | 事件轨道 | 常用:`audio_in`/`audio_out`/`control` | +| `data` | object | 结构化数据 | 顶层业务字段会镜像进 `data` 以兼容旧客户端 | + +关联ID(在 `data` 内自动注入,存在时): +- `turn_id`:一次用户-助手对话轮次 +- `utterance_id`:一次用户语音话语 +- `response_id`:一次助手生成响应 +- `tool_call_id`:一次工具调用 +- `tts_id`:一次 TTS 播放段 + +## 5.2 事件类型与参数 + +### 5.2.1 会话与控制类 + +1. `hello.ack` +- 关键字段:`version` +- 含义:握手成功,应紧接着发送 `session.start` + +2. `session.started` +- 关键字段: + - `trackId` + - `tracks.audio_in` + - `tracks.audio_out` + - `tracks.control` + - `audio`(回显客户端声明的音频元信息) +- 含义:会话进入 ACTIVE,可发音频/文本 + +3. `config.resolved` +- 关键字段: + - `config.appId` + - `config.channel` + - `config.configVersionId` + - `config.prompt.sha256` + - `config.output` + - `config.services`(去密钥后的有效服务配置) + - `config.tools.allowlist` + - `config.tracks` +- 含义:服务端最终生效配置快照,便于前端展示与排错 + +4. `heartbeat` +- 关键字段:无业务字段(仅 envelope) +- 含义:保活心跳 +- 默认间隔:`heartbeat_interval_sec`(默认 50s) + +5. `session.stopped` +- 关键字段:`reason` +- 含义:会话结束确认 + +6. `error` +- 关键字段: + - `sender` + - `code` + - `message` + - `stage` + - `retryable` + - `trackId` + - `data.error`(结构化错误镜像) +- 含义:统一错误事件 + +### 5.2.2 识别与输入侧(ASR/VAD) + +1. `input.speech_started` +- 字段:`probability` +- 含义:检测到语音开始 + +2. `input.speech_stopped` +- 字段:`probability` +- 含义:检测到语音结束 + +3. `transcript.delta` +- 字段:`text` +- 含义:ASR 增量识别文本(节流发送) + +4. `transcript.final` +- 字段:`text` +- 含义:ASR 最终识别文本 + +### 5.2.3 输出侧(LLM/TTS/Tool) + +1. `assistant.response.delta` +- 字段:`text` +- 含义:助手增量文本输出(节流发送) + +2. `assistant.response.final` +- 字段:`text` +- 含义:助手完整文本输出 + +3. `assistant.tool_call` +- 字段: + - `tool_call_id` + - `tool_name` + - `arguments`(对象) + - `executor`(`client` 或 `server`) + - `timeout_ms` + - `tool_call`(完整工具调用对象) +- 含义:通知客户端发生工具调用(用于可视化或客户端执行) + +4. `assistant.tool_result` +- 字段: + - `source`(`client` 或 `server`) + - `tool_call_id` + - `tool_name` + - `ok`(boolean) + - `error`(失败时 `{code,message,retryable}`) + - `result`(原始结果对象) +- 含义:工具调用结果回执 + +5. `output.audio.start` +- 含义:TTS 音频输出开始边界 + +6. `output.audio.end` +- 含义:TTS 音频输出结束边界 + +7. `response.interrupted` +- 含义:当前回答被打断(barge-in 或 cancel) + +8. `metrics.ttfb` +- 字段:`latencyMs` +- 含义:首包音频时延(TTFB) + +### 5.2.4 工作流扩展事件(可选) + +若 `metadata.workflow` 生效,会额外出现: +- `workflow.started` +- `workflow.node.entered` +- `workflow.edge.taken` +- `workflow.tool.requested` +- `workflow.human_transfer` +- `workflow.ended` + +这些事件用于外部可视化工作流状态,不影响基础语音会话协议。 + +--- + +## 6. 服务端二进制音频输出(服务端 -> 客户端) + +- 音频为 PCM 二进制帧; +- 发送单位对齐到 `640 bytes`(不足会补零后发送); +- 前端通常结合 `output.audio.start/end` 做播放边界控制; +- 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。 + +--- + +## 7. 
错误模型与常见错误码 + +统一结构(`error` 事件): + +```json +{ + "type": "error", + "sender": "client", + "code": "protocol.invalid_message", + "message": "Invalid message: ...", + "stage": "protocol", + "retryable": false, + "trackId": "control", + "data": { + "error": { + "stage": "protocol", + "code": "protocol.invalid_message", + "message": "Invalid message: ...", + "retryable": false + } + } +} +``` + +字段语义: +- `sender`:错误来源角色(如 `client` / `server` / `auth`) +- `code`:机器可读错误码 +- `message`:人类可读描述 +- `stage`:阶段(`protocol|audio|asr|llm|tts|tool`) +- `retryable`:是否建议重试 +- `trackId`:错误归属轨道 + +常见错误码: +- `protocol.invalid_json` +- `protocol.invalid_message` +- `protocol.order` +- `protocol.version_unsupported` +- `protocol.unsupported` +- `auth.invalid_api_key` +- `auth.required` +- `audio.invalid_pcm` +- `audio.frame_size_mismatch` +- `audio.processing_failed` +- `server.internal` + +--- + +## 8. 心跳与超时 + +服务端后台任务逻辑: +- 每隔约 5 秒检查一次连接; +- 超过 `inactivity_timeout_sec`(默认 60 秒)未收到任何客户端消息则关闭会话; +- 每隔 `heartbeat_interval_sec`(默认 50 秒)发送一次 `heartbeat`。 + +客户端建议: +- 持续上行音频或定期发送轻量文本消息,避免被判定闲置; +- 用 `heartbeat` + `seq` 检测连接活性和事件乱序。 + +--- + +## 9. 实战接入建议 + +1. 建连后立即发送 `hello`,收到 `hello.ack` 后再发 `session.start`。 +2. 语音输入严格按 16k/16bit/mono,并保证每个 WS 二进制消息长度是 `640*n`。 +3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。 +4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。 +5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。 +6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。 + +--- + +## 10. 最小完整时序示例 + +```text +Client -> hello +Server <- hello.ack +Client -> session.start +Server <- session.started +Server <- config.resolved +Client -> (binary pcm frames...) +Server <- input.speech_started / transcript.delta / transcript.final +Server <- assistant.response.delta / assistant.response.final +Server <- output.audio.start +Server <- (binary pcm frames...) +Server <- output.audio.end +Client -> session.stop +Server <- session.stopped +``` + diff --git a/engine/examples/mic_client.py b/engine/examples/mic_client.py index 509aeaa..00d403f 100644 --- a/engine/examples/mic_client.py +++ b/engine/examples/mic_client.py @@ -59,8 +59,12 @@ class MicrophoneClient: url: str, sample_rate: int = 16000, chunk_duration_ms: int = 20, + app_id: str = "assistant_demo", + channel: str = "mic_client", + config_version_id: str = "local-dev", input_device: int = None, - output_device: int = None + output_device: int = None, + track_debug: bool = False, ): """ Initialize microphone client. 
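+
+ The app_id/channel/config_version_id values are sent in session.start
+ metadata for server-side config lookup; track_debug prints each event's
+ trackId and correlation IDs.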
@@ -76,8 +80,12 @@ class MicrophoneClient: self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id self.input_device = input_device self.output_device = output_device + self.track_debug = track_debug # WebSocket connection self.ws = None @@ -106,6 +114,17 @@ class MicrophoneClient: # Verbose mode for streaming LLM responses self.verbose = False + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self) -> None: """Connect to WebSocket server.""" @@ -114,20 +133,30 @@ class MicrophoneClient: self.running = True print("Connected!") - # Send invite command + # WS v1 handshake: hello -> session.start await self.send_command({ - "command": "invite", - "option": { - "codec": "pcm", - "sampleRate": self.sample_rate - } + "type": "hello", + "version": "v1", + }) + await self.send_command({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1, + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, }) async def send_command(self, cmd: dict) -> None: """Send JSON command to server.""" if self.ws: await self.ws.send(json.dumps(cmd)) - print(f"→ Command: {cmd.get('command', 'unknown')}") + print(f"→ Command: {cmd.get('type', 'unknown')}") async def send_chat(self, text: str) -> None: """Send chat message (text input).""" @@ -136,7 +165,7 @@ class MicrophoneClient: self.first_audio_received = False await self.send_command({ - "command": "chat", + "type": "input.text", "text": text }) print(f"→ Chat: {text}") @@ -144,13 +173,14 @@ class MicrophoneClient: async def send_interrupt(self) -> None: """Send interrupt command.""" await self.send_command({ - "command": "interrupt" + "type": "response.cancel", + "graceful": False, }) async def send_hangup(self, reason: str = "User quit") -> None: """Send hangup command.""" await self.send_command({ - "command": "hangup", + "type": "session.stop", "reason": reason }) @@ -295,43 +325,48 @@ class MicrophoneClient: async def _handle_event(self, event: dict) -> None: """Handle incoming event.""" - event_type = event.get("event", "unknown") + event_type = event.get("type", event.get("event", "unknown")) + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") - if event_type == "answer": - print("← Session ready!") - elif event_type == "speaking": - print("← User speech detected") - elif event_type == "silence": - print("← User silence detected") - elif event_type == "transcript": + if event_type in {"hello.ack", "session.started"}: + print(f"← Session ready!{ids}") + elif event_type == "config.resolved": + print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}") + elif event_type == "input.speech_started": + print(f"← User speech detected{ids}") + elif event_type == "input.speech_stopped": + print(f"← User silence detected{ids}") + elif event_type in {"transcript", "transcript.delta", "transcript.final"}: # Display 
user speech transcription text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = event_type == "transcript.final" or bool(event.get("isFinal")) if is_final: # Clear the interim line and print final print(" " * 80, end="\r") # Clear previous interim text - print(f"→ You: {text}") + print(f"→ You: {text}{ids}") else: # Interim result - show with indicator (overwrite same line) display_text = text[:60] + "..." if len(text) > 60 else text print(f" [listening] {display_text}".ljust(80), end="\r") - elif event_type == "ttfb": + elif event_type in {"ttfb", "metrics.ttfb"}: # Server-side TTFB event latency_ms = event.get("latencyMs", 0) print(f"← [TTFB] Server reported latency: {latency_ms}ms") - elif event_type == "llmResponse": + elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}: # LLM text response text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = event_type == "assistant.response.final" or bool(event.get("isFinal")) if is_final: # Print final LLM response print(f"← AI: {text}") elif self.verbose: # Show streaming chunks only in verbose mode display_text = text[:60] + "..." if len(text) > 60 else text - print(f" [streaming] {display_text}") - elif event_type == "trackStart": - print("← Bot started speaking") + print(f" [streaming] {display_text}{ids}") + elif event_type in {"trackStart", "output.audio.start"}: + print(f"← Bot started speaking{ids}") # IMPORTANT: Accept audio again after trackStart self._discard_audio = False self._audio_sequence += 1 @@ -342,13 +377,13 @@ class MicrophoneClient: # Clear any old audio in buffer with self.audio_output_lock: self.audio_output_buffer = b"" - elif event_type == "trackEnd": - print("← Bot finished speaking") + elif event_type in {"trackEnd", "output.audio.end"}: + print(f"← Bot finished speaking{ids}") # Reset TTFB tracking after response completes self.request_start_time = None self.first_audio_received = False - elif event_type == "interrupt": - print("← Bot interrupted!") + elif event_type in {"interrupt", "response.interrupted"}: + print(f"← Bot interrupted!{ids}") # IMPORTANT: Discard all audio until next trackStart self._discard_audio = True # Clear audio buffer immediately @@ -357,12 +392,12 @@ class MicrophoneClient: self.audio_output_buffer = b"" print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)") elif event_type == "error": - print(f"← Error: {event.get('error')}") - elif event_type == "hangup": - print(f"← Hangup: {event.get('reason')}") + print(f"← Error: {event.get('error')}{ids}") + elif event_type in {"hangup", "session.stopped"}: + print(f"← Hangup: {event.get('reason')}{ids}") self.running = False else: - print(f"← Event: {event_type}") + print(f"← Event: {event_type}{ids}") async def interactive_mode(self) -> None: """Run interactive mode for text chat.""" @@ -573,6 +608,26 @@ async def main(): action="store_true", help="Show streaming LLM response chunks" ) + parser.add_argument( + "--app-id", + default="assistant_demo", + help="Stable app/assistant identifier for server-side config lookup" + ) + parser.add_argument( + "--channel", + default="mic_client", + help="Client channel name" + ) + parser.add_argument( + "--config-version-id", + default="local-dev", + help="Optional config version identifier" + ) + parser.add_argument( + "--track-debug", + action="store_true", + help="Print event trackId for protocol debugging" + ) args = parser.parse_args() @@ -583,8 +638,12 @@ async def main(): client = 
MicrophoneClient( url=args.url, sample_rate=args.sample_rate, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, input_device=args.input_device, - output_device=args.output_device + output_device=args.output_device, + track_debug=args.track_debug, ) client.verbose = args.verbose diff --git a/engine/examples/simple_client.py b/engine/examples/simple_client.py index 4280f93..b1648bf 100644 --- a/engine/examples/simple_client.py +++ b/engine/examples/simple_client.py @@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE: class SimpleVoiceClient: """Simple voice client with reliable audio playback.""" - def __init__(self, url: str, sample_rate: int = 16000): + def __init__( + self, + url: str, + sample_rate: int = 16000, + app_id: str = "assistant_demo", + channel: str = "simple_client", + config_version_id: str = "local-dev", + track_debug: bool = False, + ): self.url = url self.sample_rate = sample_rate + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id + self.track_debug = track_debug self.ws = None self.running = False @@ -75,6 +87,17 @@ class SimpleVoiceClient: # Interrupt handling - discard audio until next trackStart self._discard_audio = False + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self): """Connect to server.""" @@ -83,12 +106,25 @@ class SimpleVoiceClient: self.running = True print("Connected!") - # Send invite + # WS v1 handshake: hello -> session.start await self.ws.send(json.dumps({ - "command": "invite", - "option": {"codec": "pcm", "sampleRate": self.sample_rate} + "type": "hello", + "version": "v1", })) - print("-> invite") + await self.ws.send(json.dumps({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1, + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, + })) + print("-> hello/session.start") async def send_chat(self, text: str): """Send chat message.""" @@ -96,8 +132,8 @@ class SimpleVoiceClient: self.request_start_time = time.time() self.first_audio_received = False - await self.ws.send(json.dumps({"command": "chat", "text": text})) - print(f"-> chat: {text}") + await self.ws.send(json.dumps({"type": "input.text", "text": text})) + print(f"-> input.text: {text}") def play_audio(self, audio_data: bytes): """Play audio data immediately.""" @@ -152,34 +188,39 @@ class SimpleVoiceClient: else: # JSON event event = json.loads(msg) - etype = event.get("event", "?") + etype = event.get("type", event.get("event", "?")) + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}") - if etype == "transcript": + if etype in {"transcript", "transcript.delta", "transcript.final"}: # User speech transcription text = event.get("text", "") - is_final = event.get("isFinal", False) + is_final = etype == "transcript.final" or bool(event.get("isFinal")) if is_final: - print(f"<- You said: {text}") + print(f"<- You said: {text}{ids}") else: print(f"<- [listening] {text}", end="\r") - elif etype == "ttfb": + elif etype in {"ttfb", 
"metrics.ttfb"}: # Server-side TTFB event latency_ms = event.get("latencyMs", 0) print(f"<- [TTFB] Server reported latency: {latency_ms}ms") - elif etype == "trackStart": + elif etype in {"trackStart", "output.audio.start"}: # New track starting - accept audio again self._discard_audio = False - print(f"<- {etype}") - elif etype == "interrupt": + print(f"<- {etype}{ids}") + elif etype in {"interrupt", "response.interrupted"}: # Interrupt - discard audio until next trackStart self._discard_audio = True - print(f"<- {etype} (discarding audio until new track)") - elif etype == "hangup": - print(f"<- {etype}") + print(f"<- {etype}{ids} (discarding audio until new track)") + elif etype in {"hangup", "session.stopped"}: + print(f"<- {etype}{ids}") self.running = False break + elif etype == "config.resolved": + print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}") else: - print(f"<- {etype}") + print(f"<- {etype}{ids}") except asyncio.TimeoutError: continue @@ -270,6 +311,10 @@ async def main(): parser.add_argument("--text", help="Send text and play response") parser.add_argument("--list-devices", action="store_true") parser.add_argument("--sample-rate", type=int, default=16000) + parser.add_argument("--app-id", default="assistant_demo") + parser.add_argument("--channel", default="simple_client") + parser.add_argument("--config-version-id", default="local-dev") + parser.add_argument("--track-debug", action="store_true") args = parser.parse_args() @@ -277,7 +322,14 @@ async def main(): list_audio_devices() return - client = SimpleVoiceClient(args.url, args.sample_rate) + client = SimpleVoiceClient( + args.url, + args.sample_rate, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, + track_debug=args.track_debug, + ) await client.run(args.text) diff --git a/engine/examples/test_websocket.py b/engine/examples/test_websocket.py index 0d2675d..6717834 100644 --- a/engine/examples/test_websocket.py +++ b/engine/examples/test_websocket.py @@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000): return audio_data -async def receive_loop(ws, ready_event: asyncio.Event): +async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False): """Listen for incoming messages from the server.""" + def event_ids_suffix(data): + payload = data.get("data") if isinstance(data.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = payload.get(key, data.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" + print("👂 Listening for server responses...") async for msg in ws: timestamp = datetime.now().strftime("%H:%M:%S") @@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event): try: data = json.loads(msg.data) event_type = data.get('type', 'Unknown') - print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...") + ids = event_ids_suffix(data) + print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...") + if track_debug: + print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}") if event_type == "session.started": ready_event.set() except json.JSONDecodeError: @@ -113,7 +126,7 @@ async def send_sine_loop(ws): print("\n✅ Finished streaming test audio.") -async def run_client(url, file_path=None, use_sine=False): +async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False): """Run the WebSocket test 
client.""" session = aiohttp.ClientSession() try: @@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False): async with session.ws_connect(url) as ws: print("✅ Connected!") session_ready = asyncio.Event() - recv_task = asyncio.create_task(receive_loop(ws, session_ready)) + recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug)) # Send v1 hello + session.start handshake await ws.send_json({"type": "hello", "version": "v1"}) @@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False): "encoding": "pcm_s16le", "sample_rate_hz": SAMPLE_RATE, "channels": 1 - } + }, + "metadata": { + "appId": "assistant_demo", + "channel": "test_websocket", + "configVersionId": "local-dev", + }, }) print("📤 Sent v1 hello/session.start") await asyncio.wait_for(session_ready.wait(), timeout=8) @@ -168,9 +186,10 @@ if __name__ == "__main__": parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL") parser.add_argument("--file", help="Path to PCM/WAV file to stream") parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)") + parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging") args = parser.parse_args() try: - asyncio.run(run_client(args.url, args.file, args.sine)) + asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug)) except KeyboardInterrupt: print("\n👋 Client stopped.") diff --git a/engine/examples/wav_client.py b/engine/examples/wav_client.py index db638b9..5684256 100644 --- a/engine/examples/wav_client.py +++ b/engine/examples/wav_client.py @@ -57,10 +57,15 @@ class WavFileClient: url: str, input_file: str, output_file: str, + app_id: str = "assistant_demo", + channel: str = "wav_client", + config_version_id: str = "local-dev", sample_rate: int = 16000, chunk_duration_ms: int = 20, wait_time: float = 15.0, - verbose: bool = False + verbose: bool = False, + track_debug: bool = False, + tail_silence_ms: int = 800, ): """ Initialize WAV file client. 
@@ -77,11 +82,17 @@ class WavFileClient: self.url = url self.input_file = Path(input_file) self.output_file = Path(output_file) + self.app_id = app_id + self.channel = channel + self.config_version_id = config_version_id self.sample_rate = sample_rate self.chunk_duration_ms = chunk_duration_ms self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) self.wait_time = wait_time self.verbose = verbose + self.track_debug = track_debug + self.tail_silence_ms = max(0, int(tail_silence_ms)) + self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms # WebSocket connection self.ws = None @@ -105,6 +116,7 @@ class WavFileClient: self.track_started = False self.track_ended = False self.send_completed = False + self.session_ready = False # Events log self.events_log = [] @@ -124,6 +136,17 @@ class WavFileClient: # Replace problematic characters for console output safe_message = message.encode('ascii', errors='replace').decode('ascii') print(f"{direction} {safe_message}") + + @staticmethod + def _event_ids_suffix(event: dict) -> str: + data = event.get("data") if isinstance(event.get("data"), dict) else {} + keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id") + parts = [] + for key in keys: + value = data.get(key, event.get(key)) + if value: + parts.append(f"{key}={value}") + return f" [{' '.join(parts)}]" if parts else "" async def connect(self) -> None: """Connect to WebSocket server.""" @@ -131,26 +154,36 @@ class WavFileClient: self.ws = await websockets.connect(self.url) self.running = True self.log_event("←", "Connected!") - - # Send invite command + + # WS v1 handshake: hello -> session.start await self.send_command({ - "command": "invite", - "option": { - "codec": "pcm", - "sampleRate": self.sample_rate - } + "type": "hello", + "version": "v1", + }) + await self.send_command({ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1 + }, + "metadata": { + "appId": self.app_id, + "channel": self.channel, + "configVersionId": self.config_version_id, + }, }) async def send_command(self, cmd: dict) -> None: """Send JSON command to server.""" if self.ws: await self.ws.send(json.dumps(cmd)) - self.log_event("→", f"Command: {cmd.get('command', 'unknown')}") + self.log_event("→", f"Command: {cmd.get('type', 'unknown')}") async def send_hangup(self, reason: str = "Session complete") -> None: """Send hangup command.""" await self.send_command({ - "command": "hangup", + "type": "session.stop", "reason": reason }) @@ -210,6 +243,10 @@ class WavFileClient: end_sample = min(sent_samples + chunk_size, total_samples) chunk = audio_data[sent_samples:end_sample] chunk_bytes = chunk.tobytes() + if len(chunk_bytes) % self.frame_bytes != 0: + # v1 audio framing requires 640-byte (20ms) PCM units. + pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes) + chunk_bytes += b"\x00" * pad # Send to server if self.ws: @@ -226,6 +263,16 @@ class WavFileClient: # Delay to simulate real-time streaming # Server expects audio at real-time pace for VAD/ASR to work properly await asyncio.sleep(self.chunk_duration_ms / 1000) + + # Add a short silence tail to help VAD/EOU close the final utterance. 
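+ # Default 800ms tail = 40 frames of 20ms (640 zero bytes each),
+ # paced at real-time (20ms sleep per frame).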
+ if self.tail_silence_ms > 0 and self.ws: + tail_frames = max(1, self.tail_silence_ms // 20) + silence = b"\x00" * self.frame_bytes + for _ in range(tail_frames): + await self.ws.send(silence) + self.bytes_sent += len(silence) + await asyncio.sleep(0.02) + self.log_event("→", f"Sent trailing silence: {self.tail_silence_ms}ms") self.send_completed = True elapsed = time.time() - self.send_start_time @@ -277,54 +324,59 @@ class WavFileClient: async def _handle_event(self, event: dict) -> None: """Handle incoming event.""" - event_type = event.get("event", "unknown") - - if event_type == "answer": - self.log_event("←", "Session ready!") - elif event_type == "speaking": - self.log_event("←", "Speech detected") - elif event_type == "silence": - self.log_event("←", "Silence detected") - elif event_type == "transcript": - # ASR transcript (interim = asrDelta-style, final = asrFinal-style) + event_type = event.get("type", "unknown") + ids = self._event_ids_suffix(event) + if self.track_debug: + print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}") + + if event_type == "hello.ack": + self.log_event("←", f"Handshake acknowledged{ids}") + elif event_type == "session.started": + self.session_ready = True + self.log_event("←", f"Session ready!{ids}") + elif event_type == "config.resolved": + config = event.get("config", {}) + self.log_event("←", f"Config resolved (output={config.get('output', {})}){ids}") + elif event_type == "input.speech_started": + self.log_event("←", f"Speech detected{ids}") + elif event_type == "input.speech_stopped": + self.log_event("←", f"Silence detected{ids}") + elif event_type == "transcript.delta": text = event.get("text", "") - is_final = event.get("isFinal", False) - if is_final: - # Clear interim line and print final - print(" " * 80, end="\r") - self.log_event("←", f"→ You: {text}") - else: - # Interim result - show with indicator (overwrite same line, as in mic_client) - display_text = text[:60] + "..." if len(text) > 60 else text - print(f" [listening] {display_text}".ljust(80), end="\r") - elif event_type == "ttfb": + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [listening] {display_text}".ljust(80), end="\r") + elif event_type == "transcript.final": + text = event.get("text", "") + print(" " * 80, end="\r") + self.log_event("←", f"→ You: {text}{ids}") + elif event_type == "metrics.ttfb": latency_ms = event.get("latencyMs", 0) self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms") - elif event_type == "llmResponse": + elif event_type == "assistant.response.delta": text = event.get("text", "") - is_final = event.get("isFinal", False) - if is_final: - self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}") - elif self.verbose: - # Show streaming chunks only in verbose mode - self.log_event("←", f"LLM: {text}") - elif event_type == "trackStart": + if self.verbose and text: + self.log_event("←", f"LLM: {text}{ids}") + elif event_type == "assistant.response.final": + text = event.get("text", "") + if text: + self.log_event("←", f"LLM Response (final): {text[:100]}{'...' 
if len(text) > 100 else ''}{ids}") + elif event_type == "output.audio.start": self.track_started = True self.response_start_time = time.time() self.waiting_for_first_audio = True - self.log_event("←", "Bot started speaking") - elif event_type == "trackEnd": + self.log_event("←", f"Bot started speaking{ids}") + elif event_type == "output.audio.end": self.track_ended = True - self.log_event("←", "Bot finished speaking") - elif event_type == "interrupt": - self.log_event("←", "Bot interrupted!") + self.log_event("←", f"Bot finished speaking{ids}") + elif event_type == "response.interrupted": + self.log_event("←", f"Bot interrupted!{ids}") elif event_type == "error": - self.log_event("!", f"Error: {event.get('error')}") - elif event_type == "hangup": - self.log_event("←", f"Hangup: {event.get('reason')}") + self.log_event("!", f"Error: {event.get('message')}{ids}") + elif event_type == "session.stopped": + self.log_event("←", f"Session stopped: {event.get('reason')}{ids}") self.running = False else: - self.log_event("←", f"Event: {event_type}") + self.log_event("←", f"Event: {event_type}{ids}") def save_output_wav(self) -> None: """Save received audio to output WAV file.""" @@ -359,11 +411,15 @@ class WavFileClient: # Connect to server await self.connect() - # Wait for answer - await asyncio.sleep(0.5) - # Start receiver task receiver_task = asyncio.create_task(self.receiver()) + + # Wait for session.started before streaming audio + ready_start = time.time() + while self.running and not self.session_ready: + if time.time() - ready_start > 8.0: + raise TimeoutError("Timeout waiting for session.started") + await asyncio.sleep(0.05) # Send audio await self.audio_sender(audio_data) @@ -464,6 +520,21 @@ async def main(): default=16000, help="Target sample rate for audio (default: 16000)" ) + parser.add_argument( + "--app-id", + default="assistant_demo", + help="Stable app/assistant identifier for server-side config lookup" + ) + parser.add_argument( + "--channel", + default="wav_client", + help="Client channel name" + ) + parser.add_argument( + "--config-version-id", + default="local-dev", + help="Optional config version identifier" + ) parser.add_argument( "--chunk-duration", type=int, @@ -481,6 +552,17 @@ async def main(): action="store_true", help="Enable verbose output" ) + parser.add_argument( + "--track-debug", + action="store_true", + help="Print event trackId for protocol debugging" + ) + parser.add_argument( + "--tail-silence-ms", + type=int, + default=800, + help="Trailing silence to send after WAV playback for EOU detection (default: 800)" + ) args = parser.parse_args() @@ -488,10 +570,15 @@ async def main(): url=args.url, input_file=args.input, output_file=args.output, + app_id=args.app_id, + channel=args.channel, + config_version_id=args.config_version_id, sample_rate=args.sample_rate, chunk_duration_ms=args.chunk_duration, wait_time=args.wait_time, - verbose=args.verbose + verbose=args.verbose, + track_debug=args.track_debug, + tail_silence_ms=args.tail_silence_ms, ) await client.run() diff --git a/engine/examples/web_client.html b/engine/examples/web_client.html index aaeb636..3431c02 100644 --- a/engine/examples/web_client.html +++ b/engine/examples/web_client.html @@ -401,6 +401,9 @@ const targetSampleRate = 16000; const playbackStopRampSec = 0.008; + const appId = "assistant_demo"; + const channel = "web_client"; + const configVersionId = "local-dev"; function logLine(type, text, data) { const time = new Date().toLocaleTimeString(); @@ -604,15 +607,35 @@ logLine("sys", `→ 
${cmd.type}`, cmd); } + function eventIdsSuffix(event) { + const data = event && typeof event.data === "object" && event.data ? event.data : {}; + const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"]; + const parts = []; + for (const key of keys) { + const value = data[key] || event[key]; + if (value) parts.push(`${key}=${value}`); + } + return parts.length ? ` [${parts.join(" ")}]` : ""; + } + function handleEvent(event) { const type = event.type || "unknown"; - logLine("event", type, event); + const ids = eventIdsSuffix(event); + logLine("event", `${type}${ids}`, event); if (type === "hello.ack") { sendCommand({ type: "session.start", audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 }, + metadata: { + appId, + channel, + configVersionId, + }, }); } + if (type === "config.resolved") { + logLine("sys", "config.resolved", event.config || {}); + } if (type === "transcript.final") { if (event.text) { setInterim("You", ""); diff --git a/engine/models/ws_v1.py b/engine/models/ws_v1.py index b8f5524..6e67164 100644 --- a/engine/models/ws_v1.py +++ b/engine/models/ws_v1.py @@ -1,7 +1,8 @@ """WS v1 protocol message models and helpers.""" -from typing import Optional, Dict, Any, Literal -from pydantic import BaseModel, Field +from typing import Any, Dict, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, ValidationError def now_ms() -> int: @@ -11,37 +12,66 @@ def now_ms() -> int: return int(time.time() * 1000) +class _StrictModel(BaseModel): + """Protocol models reject unknown fields to enforce WS v1 schema.""" + + model_config = ConfigDict(extra="forbid") + + # Client -> Server messages -class HelloMessage(BaseModel): +class HelloAuth(_StrictModel): + apiKey: Optional[str] = None + jwt: Optional[str] = None + + +class HelloMessage(_StrictModel): type: Literal["hello"] - version: str = Field(..., description="Protocol version, currently v1") - auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. 
{'apiKey': '...'}") + version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1") + auth: Optional[HelloAuth] = Field(default=None, description="Auth payload") -class SessionStartMessage(BaseModel): +class SessionStartAudio(_StrictModel): + encoding: Literal["pcm_s16le"] = "pcm_s16le" + sample_rate_hz: Literal[16000] = 16000 + channels: Literal[1] = 1 + + +class SessionStartMessage(_StrictModel): type: Literal["session.start"] - audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata") + audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata") metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata") -class SessionStopMessage(BaseModel): +class SessionStopMessage(_StrictModel): type: Literal["session.stop"] reason: Optional[str] = None -class InputTextMessage(BaseModel): +class InputTextMessage(_StrictModel): type: Literal["input.text"] text: str -class ResponseCancelMessage(BaseModel): +class ResponseCancelMessage(_StrictModel): type: Literal["response.cancel"] graceful: bool = False -class ToolCallResultsMessage(BaseModel): +class ToolCallResultStatus(_StrictModel): + code: int + message: str + + +class ToolCallResult(_StrictModel): + tool_call_id: str + name: str + output: Any = None + status: ToolCallResultStatus + + +class ToolCallResultsMessage(_StrictModel): type: Literal["tool_call.results"] - results: list[Dict[str, Any]] = Field(default_factory=list) + results: list[ToolCallResult] = Field(default_factory=list) CLIENT_MESSAGE_TYPES = { @@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel: msg_class = CLIENT_MESSAGE_TYPES.get(msg_type) if not msg_class: raise ValueError(f"Unknown client message type: {msg_type}") - return msg_class(**data) + try: + return msg_class(**data) + except ValidationError as exc: + details = [] + for err in exc.errors(): + loc = ".".join(str(part) for part in err.get("loc", ())) + msg = err.get("msg", "invalid field") + details.append(f"{loc}: {msg}" if loc else msg) + raise ValueError("; ".join(details)) from exc # Server -> Client event helpers diff --git a/engine/requirements.txt b/engine/requirements.txt index 3d38414..d117414 100644 --- a/engine/requirements.txt +++ b/engine/requirements.txt @@ -17,6 +17,7 @@ pydantic>=2.5.3 pydantic-settings>=2.1.0 python-dotenv>=1.0.0 toml>=0.10.2 +pyyaml>=6.0.1 # Logging loguru>=0.7.2 diff --git a/engine/services/llm.py b/engine/services/llm.py index a25ff26..eb7f89c 100644 --- a/engine/services/llm.py +++ b/engine/services/llm.py @@ -7,10 +7,10 @@ for real-time voice conversation. import os import asyncio import uuid -from typing import AsyncIterator, Optional, List, Dict, Any +from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable from loguru import logger -from app.backend_client import search_knowledge_context +from app.backend_adapters import build_backend_adapter_from_settings from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState # Try to import openai @@ -37,20 +37,21 @@ class OpenAILLMService(BaseLLMService): base_url: Optional[str] = None, system_prompt: Optional[str] = None, knowledge_config: Optional[Dict[str, Any]] = None, + knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None, ): """ Initialize OpenAI LLM service. 
Args: model: Model name (e.g., "gpt-4o-mini", "gpt-4o") - api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars) base_url: Custom API base URL (for Azure or compatible APIs) system_prompt: Default system prompt for conversations """ super().__init__(model=model) - self.api_key = api_key or os.getenv("OPENAI_API_KEY") - self.base_url = base_url or os.getenv("OPENAI_API_URL") + self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL") self.system_prompt = system_prompt or ( "You are a helpful, friendly voice assistant. " "Keep your responses concise and conversational. " @@ -60,6 +61,11 @@ class OpenAILLMService(BaseLLMService): self.client: Optional[AsyncOpenAI] = None self._cancel_event = asyncio.Event() self._knowledge_config: Dict[str, Any] = knowledge_config or {} + if knowledge_searcher is None: + adapter = build_backend_adapter_from_settings() + self._knowledge_searcher = adapter.search_knowledge_context + else: + self._knowledge_searcher = knowledge_searcher self._tool_schemas: List[Dict[str, Any]] = [] _RAG_DEFAULT_RESULTS = 5 @@ -224,7 +230,7 @@ class OpenAILLMService(BaseLLMService): n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS) n_results = max(1, min(n_results, self._RAG_MAX_RESULTS)) - results = await search_knowledge_context( + results = await self._knowledge_searcher( kb_id=kb_id, query=latest_user, n_results=n_results, diff --git a/engine/services/openai_compatible_asr.py b/engine/services/openai_compatible_asr.py index daf7c04..bcf0fae 100644 --- a/engine/services/openai_compatible_asr.py +++ b/engine/services/openai_compatible_asr.py @@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti import asyncio import io +import os import wave from typing import AsyncIterator, Optional, Callable, Awaitable from loguru import logger @@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService): def __init__( self, - api_key: str, + api_key: Optional[str] = None, + api_url: Optional[str] = None, model: str = "FunAudioLLM/SenseVoiceSmall", sample_rate: int = 16000, language: str = "auto", @@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService): Args: api_key: Provider API key + api_url: Provider API URL (defaults to SiliconFlow endpoint) model: ASR model name or alias sample_rate: Audio sample rate (16000 recommended) language: Language code (auto for automatic detection) @@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService): if not AIOHTTP_AVAILABLE: raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") - self.api_key = api_key + self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY") + self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL self.model = self.MODELS.get(model.lower(), model) self.interim_interval_ms = interim_interval_ms self.min_audio_for_interim_ms = min_audio_for_interim_ms @@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService): async def connect(self) -> None: """Connect to the service.""" + if not self.api_key: + raise ValueError("ASR API key not provided. 
Configure agent.asr.api_key in YAML.") self._session = aiohttp.ClientSession( headers={ "Authorization": f"Bearer {self.api_key}" @@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService): ) form_data.add_field('model', self.model) - async with self._session.post(self.API_URL, data=form_data) as response: + async with self._session.post(self.api_url, data=form_data) as response: if response.status == 200: result = await response.json() text = result.get("text", "").strip() diff --git a/engine/services/openai_compatible_tts.py b/engine/services/openai_compatible_tts.py index 4967557..1abb1e5 100644 --- a/engine/services/openai_compatible_tts.py +++ b/engine/services/openai_compatible_tts.py @@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService): def __init__( self, api_key: Optional[str] = None, + api_url: Optional[str] = None, voice: str = "anna", model: str = "FunAudioLLM/CosyVoice2-0.5B", sample_rate: int = 16000, @@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService): Initialize OpenAI-compatible TTS service. Args: - api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var) + api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars) + api_url: Provider API URL (defaults to SiliconFlow endpoint) voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) model: Model name sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) @@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService): super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed) - self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY") + self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY") self.model = model - self.api_url = "https://api.siliconflow.cn/v1/audio/speech" + self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech" self._session: Optional[aiohttp.ClientSession] = None self._cancel_event = asyncio.Event() @@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService): async def connect(self) -> None: """Initialize HTTP session.""" if not self.api_key: - raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.") + raise ValueError("TTS API key not provided. 
Configure agent.tts.api_key in YAML.") self._session = aiohttp.ClientSession( headers={ diff --git a/engine/tests/test_agent_config.py b/engine/tests/test_agent_config.py new file mode 100644 index 0000000..c8698cb --- /dev/null +++ b/engine/tests/test_agent_config.py @@ -0,0 +1,252 @@ +import os +from pathlib import Path + +import pytest + +os.environ.setdefault("LLM_API_KEY", "test-openai-key") +os.environ.setdefault("TTS_API_KEY", "test-tts-key") +os.environ.setdefault("ASR_API_KEY", "test-asr-key") + +from app.config import load_settings + + +def _write_yaml(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + + +def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str: + return f""" +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.63 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + provider: openai_compatible + model: {llm_model} + temperature: 0.2 + api_key: {llm_key} + api_url: https://example-llm.invalid/v1 + + tts: + provider: openai_compatible + api_key: test-tts-key + api_url: https://example-tts.invalid/v1/audio/speech + model: FunAudioLLM/CosyVoice2-0.5B + voice: anna + speed: 1.0 + + asr: + provider: openai_compatible + api_key: test-asr-key + api_url: https://example-asr.invalid/v1/audio/transcriptions + model: FunAudioLLM/SenseVoiceSmall + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: You are a strict test assistant. + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 +""".strip() + + +def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + config_dir = tmp_path / "config" / "agents" + _write_yaml( + config_dir / "support.yaml", + _full_agent_yaml(llm_model="gpt-4.1-mini"), + ) + + settings = load_settings( + argv=["--agent-profile", "support"], + ) + + assert settings.llm_model == "gpt-4.1-mini" + assert settings.llm_temperature == 0.2 + assert settings.vad_threshold == 0.63 + assert settings.agent_config_source == "cli_profile" + assert settings.agent_config_path == str((config_dir / "support.yaml").resolve()) + + +def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + env_file = tmp_path / "config" / "agents" / "env.yaml" + cli_file = tmp_path / "config" / "agents" / "cli.yaml" + + _write_yaml(env_file, _full_agent_yaml(llm_model="env-model")) + _write_yaml(cli_file, _full_agent_yaml(llm_model="cli-model")) + + monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_file)) + + settings = load_settings(argv=["--agent-config", str(cli_file)]) + + assert settings.llm_model == "cli-model" + assert settings.agent_config_source == "cli_path" + assert settings.agent_config_path == str(cli_file.resolve()) + + +def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + default_file = tmp_path / "config" / "agents" / "default.yaml" + _write_yaml(default_file, _full_agent_yaml(llm_model="from-default")) + + monkeypatch.delenv("AGENT_CONFIG_PATH", raising=False) + monkeypatch.delenv("AGENT_PROFILE", raising=False) + + settings = load_settings(argv=[]) + + assert settings.llm_model == "from-default" + assert settings.agent_config_source == "default" + assert settings.agent_config_path == str(default_file.resolve()) + + +def 
test_missing_required_agent_settings_fail(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-required.yaml" + _write_yaml( + file_path, + """ +agent: + llm: + model: gpt-4o-mini +""".strip(), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_blank_required_provider_key_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "blank-key.yaml" + _write_yaml(file_path, _full_agent_yaml(llm_key="")) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_missing_tts_api_url_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-tts-url.yaml" + _write_yaml( + file_path, + _full_agent_yaml().replace( + " api_url: https://example-tts.invalid/v1/audio/speech\n", + "", + ), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_missing_asr_api_url_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-asr-url.yaml" + _write_yaml( + file_path, + _full_agent_yaml().replace( + " api_url: https://example-asr.invalid/v1/audio/transcriptions\n", + "", + ), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "bad-agent.yaml" + _write_yaml(file_path, _full_agent_yaml() + "\n unknown_option: true") + + with pytest.raises(ValueError, match="Unknown agent config keys"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "legacy-siliconflow.yaml" + _write_yaml( + file_path, + """ +agent: + siliconflow: + api_key: x +""".strip(), + ) + + with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "bad-ref.yaml" + _write_yaml( + file_path, + _full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"), + ) + + with pytest.raises(ValueError, match="Missing environment variable"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_agent_yaml_tools_list_is_loaded(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "tools-agent.yaml" + _write_yaml( + file_path, + _full_agent_yaml() + + """ + + tools: + - current_time + - name: weather + description: Get weather by city. 
+ parameters: + type: object + properties: + city: + type: string + required: [city] + executor: server +""", + ) + + settings = load_settings(argv=["--agent-config", str(file_path)]) + + assert isinstance(settings.tools, list) + assert settings.tools[0] == "current_time" + assert settings.tools[1]["name"] == "weather" + assert settings.tools[1]["executor"] == "server" + + +def test_agent_yaml_tools_must_be_list(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "bad-tools-agent.yaml" + _write_yaml( + file_path, + _full_agent_yaml() + + """ + + tools: + weather: + executor: server +""", + ) + + with pytest.raises(ValueError, match="Agent config key 'tools' must be a list"): + load_settings(argv=["--agent-config", str(file_path)]) diff --git a/engine/tests/test_backend_adapters.py b/engine/tests/test_backend_adapters.py new file mode 100644 index 0000000..d55f5e2 --- /dev/null +++ b/engine/tests/test_backend_adapters.py @@ -0,0 +1,150 @@ +import aiohttp +import pytest + +from app.backend_adapters import ( + HistoryDisabledBackendAdapter, + HttpBackendAdapter, + NullBackendAdapter, + build_backend_adapter, +) + + +@pytest.mark.asyncio +async def test_build_backend_adapter_without_url_returns_null_adapter(): + adapter = build_backend_adapter( + backend_url=None, + backend_mode="auto", + history_enabled=True, + timeout_sec=3, + ) + assert isinstance(adapter, NullBackendAdapter) + + assert await adapter.fetch_assistant_config("assistant_1") is None + assert ( + await adapter.create_call_record( + user_id=1, + assistant_id="assistant_1", + source="debug", + ) + is None + ) + assert ( + await adapter.add_transcript( + call_id="call_1", + turn_index=0, + speaker="human", + content="hi", + start_ms=0, + end_ms=100, + confidence=0.9, + duration_ms=100, + ) + is False + ) + assert ( + await adapter.finalize_call_record( + call_id="call_1", + status="connected", + duration_seconds=2, + ) + is False + ) + assert await adapter.search_knowledge_context(kb_id="kb_1", query="hello", n_results=3) == [] + assert await adapter.fetch_tool_resource("tool_1") is None + + +@pytest.mark.asyncio +async def test_http_backend_adapter_create_call_record_posts_expected_payload(monkeypatch): + captured = {} + + class _FakeResponse: + def __init__(self, status=200, payload=None): + self.status = status + self._payload = payload if payload is not None else {} + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self._payload + + def raise_for_status(self): + if self.status >= 400: + raise RuntimeError("http_error") + + class _FakeClientSession: + def __init__(self, timeout=None): + captured["timeout"] = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def post(self, url, json=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse(status=200, payload={"id": "call_123"}) + + monkeypatch.setattr("app.backend_adapters.aiohttp.ClientSession", _FakeClientSession) + + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="auto", + history_enabled=True, + timeout_sec=7, + ) + assert isinstance(adapter, HttpBackendAdapter) + + call_id = await adapter.create_call_record( + user_id=99, + assistant_id="assistant_9", + source="debug", + ) + + assert call_id == "call_123" + assert captured["url"] == "http://localhost:8100/api/history" + assert captured["json"] == { + "user_id": 99, + 
"assistant_id": "assistant_9", + "source": "debug", + "status": "connected", + } + assert isinstance(captured["timeout"], aiohttp.ClientTimeout) + assert captured["timeout"].total == 7 + + +@pytest.mark.asyncio +async def test_backend_mode_disabled_forces_null_even_with_url(): + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="disabled", + history_enabled=True, + timeout_sec=7, + ) + assert isinstance(adapter, NullBackendAdapter) + + +@pytest.mark.asyncio +async def test_history_disabled_wraps_backend_adapter(): + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="auto", + history_enabled=False, + timeout_sec=7, + ) + assert isinstance(adapter, HistoryDisabledBackendAdapter) + assert await adapter.create_call_record(user_id=1, assistant_id="a1", source="debug") is None + assert await adapter.add_transcript( + call_id="c1", + turn_index=0, + speaker="human", + content="hi", + start_ms=0, + end_ms=10, + duration_ms=10, + ) is False diff --git a/engine/tests/test_history_bridge.py b/engine/tests/test_history_bridge.py new file mode 100644 index 0000000..2f9dd80 --- /dev/null +++ b/engine/tests/test_history_bridge.py @@ -0,0 +1,147 @@ +import asyncio +import time + +import pytest + +from core.history_bridge import SessionHistoryBridge + + +class _FakeHistoryWriter: + def __init__(self, *, add_delay_s: float = 0.0, add_result: bool = True): + self.add_delay_s = add_delay_s + self.add_result = add_result + self.created_call_ids = [] + self.transcripts = [] + self.finalize_calls = 0 + self.finalize_statuses = [] + self.finalize_at = None + self.last_transcript_at = None + + async def create_call_record(self, *, user_id: int, assistant_id: str | None, source: str = "debug"): + _ = (user_id, assistant_id, source) + call_id = "call_test_1" + self.created_call_ids.append(call_id) + return call_id + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: float | None = None, + duration_ms: int | None = None, + ) -> bool: + _ = confidence + if self.add_delay_s > 0: + await asyncio.sleep(self.add_delay_s) + self.transcripts.append( + { + "call_id": call_id, + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + } + ) + self.last_transcript_at = time.monotonic() + return self.add_result + + async def finalize_call_record(self, *, call_id: str, status: str, duration_seconds: int) -> bool: + _ = (call_id, duration_seconds) + self.finalize_calls += 1 + self.finalize_statuses.append(status) + self.finalize_at = time.monotonic() + return True + +@pytest.mark.asyncio +async def test_slow_backend_does_not_block_enqueue(): + writer = _FakeHistoryWriter(add_delay_s=0.15, add_result=True) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=0, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=1.0, + ) + + try: + call_id = await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + assert call_id == "call_test_1" + + t0 = time.perf_counter() + queued = bridge.enqueue_turn(role="user", text="hello world") + elapsed_s = time.perf_counter() - t0 + + assert queued is True + assert elapsed_s < 0.02 + + await bridge.finalize(status="connected") + assert len(writer.transcripts) == 1 + assert writer.finalize_calls == 1 + finally: + await bridge.shutdown() + + 
+@pytest.mark.asyncio +async def test_failing_backend_retries_but_enqueue_remains_non_blocking(): + writer = _FakeHistoryWriter(add_delay_s=0.01, add_result=False) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=2, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=0.5, + ) + + try: + await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + t0 = time.perf_counter() + assert bridge.enqueue_turn(role="assistant", text="retry me") + elapsed_s = time.perf_counter() - t0 + assert elapsed_s < 0.02 + + await bridge.finalize(status="connected") + + # Initial try + 2 retries + assert len(writer.transcripts) == 3 + assert writer.finalize_calls == 1 + finally: + await bridge.shutdown() + + +@pytest.mark.asyncio +async def test_finalize_is_idempotent_and_waits_for_queue_drain(): + writer = _FakeHistoryWriter(add_delay_s=0.05, add_result=True) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=0, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=1.0, + ) + + try: + await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + assert bridge.enqueue_turn(role="user", text="first") + + ok_1 = await bridge.finalize(status="connected") + ok_2 = await bridge.finalize(status="connected") + + assert ok_1 is True + assert ok_2 is True + assert len(writer.transcripts) == 1 + assert writer.finalize_calls == 1 + assert writer.last_transcript_at is not None + assert writer.finalize_at is not None + assert writer.finalize_at >= writer.last_transcript_at + finally: + await bridge.shutdown() diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index e5f241b..6337edd 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -92,16 +92,44 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl return pipeline, events +def test_pipeline_uses_default_tools_from_settings(monkeypatch): + monkeypatch.setattr( + "core.duplex_pipeline.settings.tools", + [ + "current_time", + { + "name": "weather", + "description": "Get weather by city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "executor": "server", + }, + ], + ) + pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + + cfg = pipeline.resolved_runtime_config() + assert cfg["tools"]["allowlist"] == ["current_time", "weather"] + + schemas = pipeline._resolved_tool_schemas() + names = [s.get("function", {}).get("name") for s in schemas if isinstance(s, dict)] + assert "current_time" in names + assert "weather" in names + + @pytest.mark.asyncio async def test_ws_message_parses_tool_call_results(): msg = parse_client_message( { "type": "tool_call.results", - "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}], + "results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}], } ) assert isinstance(msg, ToolCallResultsMessage) - assert msg.results[0]["tool_call_id"] == "call_1" + assert msg.results[0].tool_call_id == "call_1" @pytest.mark.asyncio diff --git a/web/services/apiClient.ts b/web/services/apiClient.ts index d8c5232..afe34a6 100644 --- a/web/services/apiClient.ts +++ b/web/services/apiClient.ts @@ -6,6 +6,7 @@ const getApiBaseUrl = (): string => { const configured = import.meta.env.VITE_API_BASE_URL || DEFAULT_API_BASE_URL; 
 return trimTrailingSlash(configured);
 };
 
+export { getApiBaseUrl };
 
 type RequestOptions = {
   method?: 'GET' | 'POST' | 'PUT' | 'DELETE';
diff --git a/web/services/backendApi.ts b/web/services/backendApi.ts
index 90b52f8..b6a65ab 100644
--- a/web/services/backendApi.ts
+++ b/web/services/backendApi.ts
@@ -1,7 +1,11 @@
 import { ASRModel, Assistant, CallLog, InteractionDetail, KnowledgeBase, KnowledgeDocument, LLMModel, Tool, Voice, Workflow, WorkflowEdge, WorkflowNode } from '../types';
-import { apiRequest } from './apiClient';
+import { apiRequest, getApiBaseUrl } from './apiClient';
 
 type AnyRecord = Record<string, any>;
+const DEFAULT_LIST_LIMIT = 1000;
+
+const withLimit = (path: string, limit: number = DEFAULT_LIST_LIMIT): string =>
+  `${path}${path.includes('?') ? '&' : '?'}limit=${limit}`;
 
 const readField = <T>(obj: AnyRecord, keys: string[], fallback: T): T => {
   for (const key of keys) {
@@ -185,7 +189,8 @@ const mapKnowledgeBase = (raw: AnyRecord): KnowledgeBase => ({
 const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): CallLog => {
   const assistantId = readField(raw, ['assistant_id', 'assistantId'], '');
   const startTime = normalizeDateLabel(readField(raw, ['started_at', 'startTime'], ''));
-  const type = readField(raw, ['type'], 'text');
+  const rawType = String(readField(raw, ['type'], 'text')).toLowerCase();
+  const type: CallLog['type'] = rawType === 'audio' || rawType === 'video' ? rawType : 'text';
   return {
     id: String(readField(raw, ['id'], '')),
     source: readField(raw, ['source'], 'debug') as 'debug' | 'external',
@@ -193,7 +198,7 @@ const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): Ca
     startTime,
     duration: formatDuration(readField(raw, ['duration_seconds', 'durationSeconds'], 0)),
     agentName: assistantNameMap.get(String(assistantId)) || String(assistantId || 'Unknown Assistant'),
-    type: type === 'audio' || type === 'video' ? type : 'text',
+    type,
     details: [],
   };
 };
@@ -209,7 +214,7 @@ const toHistoryDetails = (raw: AnyRecord): InteractionDetail[] => {
 };
 
 export const fetchAssistants = async (): Promise<Assistant[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapAssistant(item));
 };
@@ -276,6 +281,8 @@ export const deleteAssistant = async (id: string): Promise<void> => {
 
 export interface AssistantRuntimeConfigResponse {
   assistantId: string;
+  configVersionId?: string;
+  assistant?: Record<string, unknown>;
   sessionStartMetadata: Record<string, unknown>;
   sources?: {
     llmModelId?: string;
@@ -286,11 +293,11 @@
 }
 
 export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise<AssistantRuntimeConfigResponse> => {
-  return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/runtime-config`);
+  return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/config`);
 };
 
 export const fetchVoices = async (): Promise<Voice[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/voices'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapVoice(item));
 };
 
@@ -351,7 +358,7 @@ export const previewVoice = async (id: string, text: string, speed?: number, api
 };
 
 export const fetchASRModels = async (): Promise<ASRModel[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/asr');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/asr'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapASRModel(item));
 };
@@ -418,8 +425,7 @@ export const previewASRModel = async (
     formData.append('api_key', options.apiKey);
   }
 
-  const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
-  const url = `${base}/asr/${id}/preview`;
+  const url = `${getApiBaseUrl()}/asr/${id}/preview`;
   const response = await fetch(url, {
     method: 'POST',
     body: formData,
@@ -441,7 +447,7 @@
 };
 
 export const fetchLLMModels = async (): Promise<LLMModel[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/llm');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/llm'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapLLMModel(item));
 };
@@ -501,7 +507,7 @@
 };
 
 export const fetchTools = async (): Promise<Tool[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/tools/resources');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/tools/resources'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapTool(item));
 };
@@ -544,18 +550,14 @@ export const deleteTool = async (id: string): Promise<void> => {
 };
 
 export const fetchWorkflows = async (): Promise<Workflow[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/workflows');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/workflows'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapWorkflow(item));
 };
 
 export const fetchWorkflowById = async (id: string): Promise<Workflow> => {
-  const list = await fetchWorkflows();
-  const workflow = list.find((item) => item.id === id);
-  if (!workflow) {
-    throw new Error('Workflow not found');
-  }
-  return workflow;
+  const response = await apiRequest<AnyRecord>(`/workflows/${id}`);
+  return mapWorkflow(response);
 };
 
 export const createWorkflow = async (data: Partial<Workflow>): Promise<Workflow> => {
@@ -589,7 +591,7 @@ export const deleteWorkflow = async (id: string): Promise<void> => {
 };
 
 export const fetchKnowledgeBases = async (): Promise<KnowledgeBase[]> => {
-  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/knowledge/bases');
+  const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/knowledge/bases'));
   const list = Array.isArray(response) ? response : (response.list || []);
   return list.map((item) => mapKnowledgeBase(item));
 };
 
@@ -641,8 +643,7 @@ export const uploadKnowledgeDocument = async (kbId: string, file: File): Promise
   formData.append('size', `${file.size} bytes`);
   formData.append('file_type', file.type || 'application/octet-stream');
 
-  const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
-  const url = `${base}/knowledge/bases/${kbId}/documents`;
+  const url = `${getApiBaseUrl()}/knowledge/bases/${kbId}/documents`;
   const response = await fetch(url, {
     method: 'POST',
     body: formData,
@@ -716,8 +717,8 @@ export const searchKnowledgeBase = async (
 export const fetchHistory = async (): Promise<CallLog[]> => {
   const [historyResp, assistantsResp] = await Promise.all([
-    apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/history'),
-    apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants'),
+    apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/history')),
+    apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants')),
   ]);
 
   const assistantList = Array.isArray(assistantsResp) ? assistantsResp : (assistantsResp.list || []);
diff --git a/web/tsconfig.json b/web/tsconfig.json
index 2c6eed5..d15d242 100644
--- a/web/tsconfig.json
+++ b/web/tsconfig.json
@@ -11,7 +11,8 @@
     ],
     "skipLibCheck": true,
     "types": [
-      "node"
+      "node",
+      "vite/client"
     ],
     "moduleResolution": "bundler",
     "isolatedModules": true,
@@ -26,4 +27,4 @@
     "allowImportingTsExtensions": true,
     "noEmit": true
   }
-}
\ No newline at end of file
+}
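
Usage sketch, provider endpoint overrides: after this patch, OpenAICompatibleTTSService accepts api_key/api_url directly and falls back to the TTS_API_KEY/TTS_API_URL env vars before the legacy SiliconFlow defaults; the ASR service mirrors the same pattern with its own api_url. A minimal sketch, assuming the engine root is on sys.path (as in the tests) and that connect() alone is enough to validate credentials; the key and URL below are placeholders, not real endpoints:

import asyncio

from services.openai_compatible_tts import OpenAICompatibleTTSService


async def main() -> None:
    # Constructor arguments win over env vars; env vars win over the
    # SiliconFlow defaults kept for backward compatibility.
    tts = OpenAICompatibleTTSService(
        api_key="sk-example",                                   # placeholder key
        api_url="https://example-tts.invalid/v1/audio/speech",  # placeholder URL
        voice="anna",
        model="FunAudioLLM/CosyVoice2-0.5B",
        sample_rate=16000,
    )
    await tts.connect()  # raises ValueError when no api_key can be resolved


asyncio.run(main())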
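
Configuration precedence, as pinned down by test_agent_config.py: an explicit --agent-config path wins over the AGENT_CONFIG_PATH env var; --agent-profile NAME resolves to config/agents/NAME.yaml; and with neither flag nor env var set, config/agents/default.yaml is loaded. A condensed sketch of that contract (it assumes the referenced YAML files exist and pass validation; see _full_agent_yaml in the tests for a complete profile):

from app.config import load_settings

# 1. Explicit path: highest priority, beats AGENT_CONFIG_PATH.
settings = load_settings(argv=["--agent-config", "config/agents/cli.yaml"])
assert settings.agent_config_source == "cli_path"

# 2. Named profile: resolved to config/agents/support.yaml.
settings = load_settings(argv=["--agent-profile", "support"])
assert settings.agent_config_source == "cli_profile"

# 3. No flags, no env var: config/agents/default.yaml is used.
settings = load_settings(argv=[])
assert settings.agent_config_source == "default"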
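
How the new history pieces compose at session start: build_backend_adapter returns NullBackendAdapter when no backend URL is configured or backend_mode is "disabled", HttpBackendAdapter otherwise, and wraps the result in HistoryDisabledBackendAdapter when history is off. SessionHistoryBridge then keeps enqueue_turn() non-blocking by draining turns to that adapter from a background queue, retrying failed writes, and draining the queue before finalize() completes. A sketch under the assumption that the adapter satisfies the same writer interface as the fake in test_history_bridge.py:

import asyncio

from app.backend_adapters import build_backend_adapter
from core.history_bridge import SessionHistoryBridge


async def main() -> None:
    # backend_url=None yields the no-op NullBackendAdapter, so this sketch runs
    # without a live backend; a real deployment would pass e.g.
    # backend_url="http://localhost:8100" to get HttpBackendAdapter.
    adapter = build_backend_adapter(
        backend_url=None,
        backend_mode="auto",
        history_enabled=True,
        timeout_sec=7,
    )
    bridge = SessionHistoryBridge(
        history_writer=adapter,
        enabled=True,
        queue_max_size=32,
        retry_max_attempts=2,
        retry_backoff_sec=0.01,
        finalize_drain_timeout_sec=1.0,
    )
    try:
        await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
        bridge.enqueue_turn(role="user", text="hello")  # returns immediately, never blocks on I/O
        await bridge.finalize(status="connected")       # drains the queue; idempotent
    finally:
        await bridge.shutdown()


asyncio.run(main())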