Unify DB API

Xin Wang
2026-02-26 01:58:39 +08:00
parent 56f8aa2191
commit 72ed7d0512
40 changed files with 3926 additions and 593 deletions

.gitignore (new vendored file)

@@ -0,0 +1,6 @@
# OS artifacts
.DS_Store
Thumbs.db
# Workspace runtime data
data/

api/.gitignore (vendored)

@@ -36,8 +36,12 @@ env/
*.sqlite
*.sqlite3
# Vector store data
data/vector_store/
# Runtime data (SQLite, vector store, uploads, generated artifacts)
data/**
!data/
!data/.gitkeep
!data/vector_store/
data/vector_store/**
!data/vector_store/.gitkeep
# IDE


@@ -1,13 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional
from typing import Any, Dict, List, Optional
import uuid
from datetime import datetime
from ..db import get_db
from ..models import Assistant, LLMModel, ASRModel, Voice
from ..schemas import (
AssistantCreate, AssistantUpdate, AssistantOut
AssistantCreate, AssistantUpdate, AssistantOut, AssistantEngineConfigResponse
)
router = APIRouter(prefix="/assistants", tags=["Assistants"])
@@ -52,8 +52,13 @@ def _normalize_openai_compatible_voice_key(voice_value: str, model: str) -> str:
return f"{model_name}:{voice_id}"
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
metadata = {
def _config_version_id(assistant: Assistant) -> str:
updated = assistant.updated_at or assistant.created_at or datetime.utcnow()
return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}"
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
metadata: Dict[str, Any] = {
"systemPrompt": assistant.prompt or "",
"firstTurnMode": assistant.first_turn_mode or "bot_first",
"greeting": assistant.opener or "",
@@ -64,10 +69,29 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"minDurationMs": int(assistant.interruption_sensitivity or 500),
},
"services": {},
"tools": assistant.tools or [],
"history": {
"assistantId": assistant.id,
"userId": int(assistant.user_id or 1),
"source": "debug",
},
}
warnings = []
warnings: List[str] = []
if assistant.llm_model_id:
config_mode = str(assistant.config_mode or "platform").strip().lower()
if config_mode in {"dify", "fastgpt"}:
metadata["services"]["llm"] = {
"provider": "openai",
"model": "",
"apiKey": assistant.api_key,
"baseUrl": assistant.api_url,
}
if not (assistant.api_url or "").strip():
warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}")
if not (assistant.api_key or "").strip():
warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}")
elif assistant.llm_model_id:
llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
if llm:
metadata["services"]["llm"] = {
@@ -87,6 +111,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"provider": asr_provider,
"model": asr.model_name or asr.name,
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
"baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
}
else:
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
@@ -107,6 +132,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"provider": tts_provider,
"model": model,
"apiKey": voice.api_key if tts_provider == "openai_compatible" else None,
"baseUrl": voice.base_url if tts_provider == "openai_compatible" else None,
"voice": runtime_voice,
"speed": assistant.speed or voice.speed,
}
@@ -126,10 +152,21 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"kbId": assistant.knowledge_base_id,
"nResults": 5,
}
return metadata, warnings
def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[str, Any]:
session_metadata, warnings = _resolve_runtime_metadata(db, assistant)
config_version_id = _config_version_id(assistant)
assistant_cfg = dict(session_metadata)
assistant_cfg["assistantId"] = assistant.id
assistant_cfg["configVersionId"] = config_version_id
return {
"assistantId": assistant.id,
"sessionStartMetadata": metadata,
"configVersionId": config_version_id,
"assistant": assistant_cfg,
"sessionStartMetadata": session_metadata,
"sources": {
"llmModelId": assistant.llm_model_id,
"asrModelId": assistant.asr_model_id,
@@ -219,13 +256,22 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
return assistant_to_dict(assistant)
@router.get("/{id}/runtime-config")
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
"""Resolve assistant runtime config for engine WS session.start metadata."""
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
def get_assistant_config(id: str, db: Session = Depends(get_db)):
"""Canonical engine config endpoint consumed by engine backend adapter."""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _resolve_runtime_metadata(db, assistant)
return _build_engine_assistant_config(db, assistant)
@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse)
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
"""Legacy alias for resolved engine runtime config."""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _build_engine_assistant_config(db, assistant)
@router.post("", response_model=AssistantOut)


@@ -333,6 +333,35 @@ class AssistantOut(AssistantBase):
from_attributes = True
class AssistantRuntimeMetadata(BaseModel):
"""Canonical runtime metadata payload consumed by engine session.start."""
model_config = ConfigDict(extra="allow")
systemPrompt: str = ""
firstTurnMode: str = "bot_first"
greeting: str = ""
generatedOpenerEnabled: bool = False
output: Dict[str, Any] = Field(default_factory=dict)
bargeIn: Dict[str, Any] = Field(default_factory=dict)
services: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
tools: List[Any] = Field(default_factory=list)
knowledgeBaseId: Optional[str] = None
knowledge: Dict[str, Any] = Field(default_factory=dict)
history: Dict[str, Any] = Field(default_factory=dict)
assistantId: Optional[str] = None
configVersionId: Optional[str] = None
class AssistantEngineConfigResponse(BaseModel):
assistantId: str
configVersionId: Optional[str] = None
assistant: AssistantRuntimeMetadata
sessionStartMetadata: AssistantRuntimeMetadata
sources: Dict[str, Optional[str]] = Field(default_factory=dict)
warnings: List[str] = Field(default_factory=list)
class AssistantStats(BaseModel):
assistant_id: str
total_calls: int = 0


@@ -183,12 +183,16 @@ class TestAssistantAPI:
def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
"""Test resolved runtime config endpoint for WS session.start metadata."""
sample_asr_model_data["vendor"] = "OpenAI Compatible"
llm_resp = client.post("/api/llm", json=sample_llm_model_data)
assert llm_resp.status_code == 200
asr_resp = client.post("/api/asr", json=sample_asr_model_data)
assert asr_resp.status_code == 200
sample_voice_data["vendor"] = "OpenAI Compatible"
sample_voice_data["base_url"] = "https://tts.example.com/v1/audio/speech"
sample_voice_data["api_key"] = "test-voice-key"
voice_resp = client.post("/api/voices", json=sample_voice_data)
assert voice_resp.status_code == 200
voice_id = voice_resp.json()["id"]
@@ -215,7 +219,26 @@ class TestAssistantAPI:
assert metadata["greeting"] == "runtime opener"
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
def test_get_engine_config_endpoint(self, client, sample_assistant_data):
"""Test canonical assistant config endpoint consumed by engine backend adapter."""
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
config_resp = client.get(f"/api/assistants/{assistant_id}/config")
assert config_resp.status_code == 200
payload = config_resp.json()
assert payload["assistantId"] == assistant_id
assert payload["assistant"]["assistantId"] == assistant_id
assert payload["assistant"]["configVersionId"].startswith(f"asst_{assistant_id}_")
assert payload["assistant"]["systemPrompt"] == sample_assistant_data["prompt"]
assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"]
assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id
def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data):
sample_assistant_data["voiceOutputEnabled"] = False


@@ -11,9 +11,16 @@ PORT=8000
# EXTERNAL_IP=1.2.3.4
# Backend bridge (optional)
# BACKEND_MODE=auto|http|disabled
BACKEND_MODE=auto
BACKEND_URL=http://127.0.0.1:8100
BACKEND_TIMEOUT_SEC=10
HISTORY_ENABLED=true
HISTORY_DEFAULT_USER_ID=1
HISTORY_QUEUE_MAX_SIZE=256
HISTORY_RETRY_MAX_ATTEMPTS=2
HISTORY_RETRY_BACKOFF_SEC=0.2
HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC=1.5
# Audio
SAMPLE_RATE=16000
@@ -23,57 +30,21 @@ CHUNK_SIZE_MS=20
DEFAULT_CODEC=pcm
MAX_AUDIO_BUFFER_SECONDS=30
# VAD / EOU
VAD_TYPE=silero
VAD_MODEL_PATH=data/vad/silero_vad.onnx
# Higher = stricter speech detection (fewer false positives, more misses).
VAD_THRESHOLD=0.5
# Require this much continuous speech before utterance can be valid.
VAD_MIN_SPEECH_DURATION_MS=100
# Silence duration required to finalize one user turn.
VAD_EOU_THRESHOLD_MS=800
# Agent profile selection (optional fallback when CLI args are not used)
# Prefer CLI:
# python -m app.main --agent-config config/agents/default.yaml
# python -m app.main --agent-profile default
# AGENT_CONFIG_PATH=config/agents/default.yaml
# AGENT_PROFILE=default
AGENT_CONFIG_DIR=config/agents
# LLM
OPENAI_API_KEY=your_openai_api_key_here
# Optional for OpenAI-compatible providers.
# OPENAI_API_URL=https://api.openai.com/v1
LLM_MODEL=gpt-4o-mini
LLM_TEMPERATURE=0.7
# TTS
# edge: no API key needed
# openai_compatible: compatible with SiliconFlow-style endpoints
TTS_PROVIDER=openai_compatible
TTS_VOICE=anna
TTS_SPEED=1.0
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
# ASR
ASR_PROVIDER=openai_compatible
# Interim cadence and minimum audio before interim decode.
ASR_INTERIM_INTERVAL_MS=500
ASR_MIN_AUDIO_MS=300
# ASR start gate: ignore micro-noise, then commit to one turn once started.
ASR_START_MIN_SPEECH_MS=160
# Pre-roll protects beginning phonemes.
ASR_PRE_SPEECH_MS=240
# Tail silence protects ending phonemes.
ASR_FINAL_TAIL_MS=120
# Duplex behavior
DUPLEX_ENABLED=true
# DUPLEX_GREETING=Hello! How can I help you today?
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
# Barge-in (user interrupting assistant)
# Min user speech duration needed to interrupt assistant audio.
BARGE_IN_MIN_DURATION_MS=200
# Allowed silence during potential barge-in (ms) before reset.
BARGE_IN_SILENCE_TOLERANCE_MS=60
# Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY}
# LLM_API_KEY=your_llm_api_key_here
# LLM_API_URL=https://api.openai.com/v1
# TTS_API_KEY=your_tts_api_key_here
# TTS_API_URL=https://api.example.com/v1/audio/speech
# ASR_API_KEY=your_asr_api_key_here
# ASR_API_URL=https://api.example.com/v1/audio/transcriptions
# Logging
LOG_LEVEL=INFO

engine/.gitignore (vendored)

@@ -146,3 +146,5 @@ cython_debug/
recordings/
logs/
running/
config/agents/default.yaml


@@ -14,6 +14,57 @@ It is currently in an early, experimental stage.
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Use an agent profile (recommended):
```
python -m app.main --agent-profile default
```
Use a specific YAML file:
```
python -m app.main --agent-config config/agents/default.yaml
```
Agent config path priority:
1. `--agent-config`
2. `--agent-profile` (maps to `config/agents/<profile>.yaml`)
3. `AGENT_CONFIG_PATH`
4. `AGENT_PROFILE`
5. `config/agents/default.yaml` (if it exists)
Notes:
- Agent configuration is strict: if the YAML is missing a required key, startup fails immediately; there is no fallback to `.env` or code defaults.
- To reference an environment variable, write `${ENV_VAR}` explicitly in the YAML.
- The standalone `siliconflow` section has been removed; configure `provider`, `api_key`, `api_url`, and `model` inside `agent.llm` / `agent.tts` / `agent.asr`.
- `agent.tools` (a list) is now supported in the agent YAML for declaring tools callable at runtime.
- See `config/agents/tools.yaml` for a tool configuration example, and the selection sketch below for how these sources are resolved.
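The precedence above is implemented by `load_settings` in the engine config module. A minimal sketch of exercising it directly, assuming the engine package is importable and the referenced YAML exists and passes the strict validation:
```
from app.config import load_settings

# An explicit path wins over profiles, env vars, and the default file.
settings = load_settings(agent_config_path="config/agents/default.yaml")

# Equivalent to passing CLI flags; only agent-related flags are parsed.
settings = load_settings(argv=["--agent-profile", "default"])

# Both record how the YAML was chosen.
print(settings.agent_config_source)  # e.g. "cli_path" or "cli_profile"
print(settings.agent_config_path)
```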
## Backend Integration
Engine runtime now supports adapter-based backend integration:
- `BACKEND_MODE=auto|http|disabled`
- `BACKEND_URL` + `BACKEND_TIMEOUT_SEC`
- `HISTORY_ENABLED=true|false`
Behavior:
- `auto`: use HTTP backend only when `BACKEND_URL` is set, otherwise engine-only mode.
- `http`: force HTTP backend; falls back to engine-only mode when URL is missing.
- `disabled`: force engine-only mode (no backend calls).
History write path is now asynchronous and buffered per session:
- `HISTORY_QUEUE_MAX_SIZE`
- `HISTORY_RETRY_MAX_ATTEMPTS`
- `HISTORY_RETRY_BACKOFF_SEC`
- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC`
This keeps turn processing responsive even when backend history APIs are slow/failing.
Detailed notes: `docs/backend_integration.md`.
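As a rough sketch of the mode matrix above, `build_backend_adapter` can be exercised directly; the URL and assistant ID below are placeholders:
```
import asyncio
from app.backend_adapters import build_backend_adapter

async def main() -> None:
    # auto + empty URL -> NullBackendAdapter: every backend call is a no-op.
    offline = build_backend_adapter(backend_url=None, backend_mode="auto")
    assert await offline.fetch_assistant_config("asst_demo") is None

    # http + URL with history disabled -> config/knowledge reads still hit
    # the backend, while history writes are silently skipped.
    bridged = build_backend_adapter(
        backend_url="http://127.0.0.1:8100",  # placeholder backend URL
        backend_mode="http",
        history_enabled=False,
    )
    assert await bridged.create_call_record(user_id=1, assistant_id=None) is None

asyncio.run(main())
```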
Tests:
```
@@ -28,4 +79,4 @@ python mic_client.py
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`.
See `docs/ws_v1_schema.md`.
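For orientation, an event envelope as stamped by the pipeline's `_envelope_event` helper in this commit looks roughly like the sketch below; the concrete event type and IDs are illustrative:
```
# Hedged sketch of one v1 event envelope; field values are illustrative.
event = {
    "type": "transcript.final",        # an assumed ASR event type
    "sessionId": "sess_0001",          # placeholder session id
    "seq": 42,                         # session-scoped monotonic sequence
    "source": "asr",                   # asr | llm | tool | tts | system
    "trackId": "audio_in",             # audio_in | audio_out | control
    "data": {
        "turn_id": "turn_1_ab12cd34",  # per-turn correlation ids
        "utterance_id": "utt_1_ef56cd90",
    },
}
```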


@@ -0,0 +1,357 @@
"""Backend adapter implementations for engine integration ports."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import aiohttp
from loguru import logger
from app.config import settings
class NullBackendAdapter:
"""No-op adapter for engine-only runtime without backend dependencies."""
async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
_ = assistant_id
return None
async def create_call_record(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str = "debug",
) -> Optional[str]:
_ = (user_id, assistant_id, source)
return None
async def add_transcript(
self,
*,
call_id: str,
turn_index: int,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: Optional[float] = None,
duration_ms: Optional[int] = None,
) -> bool:
_ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms)
return False
async def finalize_call_record(
self,
*,
call_id: str,
status: str,
duration_seconds: int,
) -> bool:
_ = (call_id, status, duration_seconds)
return False
async def search_knowledge_context(
self,
*,
kb_id: str,
query: str,
n_results: int = 5,
) -> List[Dict[str, Any]]:
_ = (kb_id, query, n_results)
return []
async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
_ = tool_id
return None
class HistoryDisabledBackendAdapter:
"""Adapter wrapper that disables history writes while keeping reads available."""
def __init__(self, delegate: HttpBackendAdapter | NullBackendAdapter):
self._delegate = delegate
async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
return await self._delegate.fetch_assistant_config(assistant_id)
async def create_call_record(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str = "debug",
) -> Optional[str]:
_ = (user_id, assistant_id, source)
return None
async def add_transcript(
self,
*,
call_id: str,
turn_index: int,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: Optional[float] = None,
duration_ms: Optional[int] = None,
) -> bool:
_ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms)
return False
async def finalize_call_record(
self,
*,
call_id: str,
status: str,
duration_seconds: int,
) -> bool:
_ = (call_id, status, duration_seconds)
return False
async def search_knowledge_context(
self,
*,
kb_id: str,
query: str,
n_results: int = 5,
) -> List[Dict[str, Any]]:
return await self._delegate.search_knowledge_context(
kb_id=kb_id,
query=query,
n_results=n_results,
)
async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
return await self._delegate.fetch_tool_resource(tool_id)
class HttpBackendAdapter:
"""HTTP implementation of backend integration ports."""
def __init__(self, backend_url: str, timeout_sec: int = 10):
base_url = str(backend_url or "").strip().rstrip("/")
if not base_url:
raise ValueError("backend_url is required for HttpBackendAdapter")
self._base_url = base_url
self._timeout_sec = timeout_sec
def _timeout(self) -> aiohttp.ClientTimeout:
return aiohttp.ClientTimeout(total=self._timeout_sec)
async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
"""Fetch assistant config payload from backend API.
Expected response shape:
{
"assistant": {...},
"voice": {...} | null
}
"""
url = f"{self._base_url}/api/assistants/{assistant_id}/config"
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.get(url) as resp:
if resp.status == 404:
logger.warning(f"Assistant config not found: {assistant_id}")
return None
resp.raise_for_status()
payload = await resp.json()
if not isinstance(payload, dict):
logger.warning("Assistant config payload is not a dict; ignoring")
return None
return payload
except Exception as exc:
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
return None
async def create_call_record(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str = "debug",
) -> Optional[str]:
"""Create a call record via backend history API and return call_id."""
url = f"{self._base_url}/api/history"
payload: Dict[str, Any] = {
"user_id": user_id,
"assistant_id": assistant_id,
"source": source,
"status": "connected",
}
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.post(url, json=payload) as resp:
resp.raise_for_status()
data = await resp.json()
call_id = str((data or {}).get("id") or "")
return call_id or None
except Exception as exc:
logger.warning(f"Failed to create history call record: {exc}")
return None
async def add_transcript(
self,
*,
call_id: str,
turn_index: int,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: Optional[float] = None,
duration_ms: Optional[int] = None,
) -> bool:
"""Append a transcript segment to backend history."""
if not call_id:
return False
url = f"{self._base_url}/api/history/{call_id}/transcripts"
payload: Dict[str, Any] = {
"turn_index": turn_index,
"speaker": speaker,
"content": content,
"confidence": confidence,
"start_ms": start_ms,
"end_ms": end_ms,
"duration_ms": duration_ms,
}
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.post(url, json=payload) as resp:
resp.raise_for_status()
return True
except Exception as exc:
logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
return False
async def finalize_call_record(
self,
*,
call_id: str,
status: str,
duration_seconds: int,
) -> bool:
"""Finalize a call record with status and duration."""
if not call_id:
return False
url = f"{self._base_url}/api/history/{call_id}"
payload: Dict[str, Any] = {
"status": status,
"duration_seconds": duration_seconds,
}
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.put(url, json=payload) as resp:
resp.raise_for_status()
return True
except Exception as exc:
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
return False
async def search_knowledge_context(
self,
*,
kb_id: str,
query: str,
n_results: int = 5,
) -> List[Dict[str, Any]]:
"""Search backend knowledge base and return retrieval results."""
if not kb_id or not query.strip():
return []
try:
safe_n_results = max(1, int(n_results))
except (TypeError, ValueError):
safe_n_results = 5
url = f"{self._base_url}/api/knowledge/search"
payload: Dict[str, Any] = {
"kb_id": kb_id,
"query": query,
"nResults": safe_n_results,
}
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.post(url, json=payload) as resp:
if resp.status == 404:
logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
return []
resp.raise_for_status()
data = await resp.json()
if not isinstance(data, dict):
return []
results = data.get("results", [])
if not isinstance(results, list):
return []
return [r for r in results if isinstance(r, dict)]
except Exception as exc:
logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
return []
async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
"""Fetch tool resource configuration from backend API."""
if not tool_id:
return None
url = f"{self._base_url}/api/tools/resources/{tool_id}"
try:
async with aiohttp.ClientSession(timeout=self._timeout()) as session:
async with session.get(url) as resp:
if resp.status == 404:
return None
resp.raise_for_status()
data = await resp.json()
return data if isinstance(data, dict) else None
except Exception as exc:
logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
return None
def build_backend_adapter(
*,
backend_url: Optional[str],
backend_mode: str = "auto",
history_enabled: bool = True,
timeout_sec: int = 10,
) -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter:
"""Create backend adapter implementation based on runtime settings."""
mode = str(backend_mode or "auto").strip().lower()
has_url = bool(str(backend_url or "").strip())
base_adapter: HttpBackendAdapter | NullBackendAdapter
if mode in {"disabled", "off", "none", "null", "engine_only", "engine-only"}:
base_adapter = NullBackendAdapter()
elif mode == "http":
if has_url:
base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec)
else:
logger.warning("BACKEND_MODE=http but BACKEND_URL is empty; falling back to NullBackendAdapter")
base_adapter = NullBackendAdapter()
else:
if has_url:
base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec)
else:
base_adapter = NullBackendAdapter()
if not history_enabled:
return HistoryDisabledBackendAdapter(base_adapter)
return base_adapter
def build_backend_adapter_from_settings() -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter:
"""Create backend adapter using current app settings."""
return build_backend_adapter(
backend_url=settings.backend_url,
backend_mode=settings.backend_mode,
history_enabled=settings.history_enabled,
timeout_sec=settings.backend_timeout_sec,
)


@@ -1,56 +1,19 @@
"""Backend API client for assistant config and history persistence."""
"""Compatibility wrappers around backend adapter implementations."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import aiohttp
from loguru import logger
from app.backend_adapters import build_backend_adapter_from_settings
from app.config import settings
def _adapter():
return build_backend_adapter_from_settings()
async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]:
"""Fetch assistant config payload from backend API.
Expected response shape:
{
"assistant": {...},
"voice": {...} | null
}
"""
if not settings.backend_url:
logger.warning("BACKEND_URL not set; skipping assistant config fetch")
return None
url = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config"
timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as resp:
if resp.status == 404:
logger.warning(f"Assistant config not found: {assistant_id}")
return None
resp.raise_for_status()
payload = await resp.json()
if not isinstance(payload, dict):
logger.warning("Assistant config payload is not a dict; ignoring")
return None
return payload
except Exception as exc:
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
return None
def _backend_base_url() -> Optional[str]:
if not settings.backend_url:
return None
return settings.backend_url.rstrip("/")
def _timeout() -> aiohttp.ClientTimeout:
return aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
"""Fetch assistant config payload from backend adapter."""
return await _adapter().fetch_assistant_config(assistant_id)
async def create_history_call_record(
@@ -60,28 +23,11 @@ async def create_history_call_record(
source: str = "debug",
) -> Optional[str]:
"""Create a call record via backend history API and return call_id."""
base_url = _backend_base_url()
if not base_url:
return None
url = f"{base_url}/api/history"
payload: Dict[str, Any] = {
"user_id": user_id,
"assistant_id": assistant_id,
"source": source,
"status": "connected",
}
try:
async with aiohttp.ClientSession(timeout=_timeout()) as session:
async with session.post(url, json=payload) as resp:
resp.raise_for_status()
data = await resp.json()
call_id = str((data or {}).get("id") or "")
return call_id or None
except Exception as exc:
logger.warning(f"Failed to create history call record: {exc}")
return None
return await _adapter().create_call_record(
user_id=user_id,
assistant_id=assistant_id,
source=source,
)
async def add_history_transcript(
@@ -96,29 +42,16 @@ async def add_history_transcript(
duration_ms: Optional[int] = None,
) -> bool:
"""Append a transcript segment to backend history."""
base_url = _backend_base_url()
if not base_url or not call_id:
return False
url = f"{base_url}/api/history/{call_id}/transcripts"
payload: Dict[str, Any] = {
"turn_index": turn_index,
"speaker": speaker,
"content": content,
"confidence": confidence,
"start_ms": start_ms,
"end_ms": end_ms,
"duration_ms": duration_ms,
}
try:
async with aiohttp.ClientSession(timeout=_timeout()) as session:
async with session.post(url, json=payload) as resp:
resp.raise_for_status()
return True
except Exception as exc:
logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
return False
return await _adapter().add_transcript(
call_id=call_id,
turn_index=turn_index,
speaker=speaker,
content=content,
start_ms=start_ms,
end_ms=end_ms,
confidence=confidence,
duration_ms=duration_ms,
)
async def finalize_history_call_record(
@@ -128,24 +61,11 @@ async def finalize_history_call_record(
duration_seconds: int,
) -> bool:
"""Finalize a call record with status and duration."""
base_url = _backend_base_url()
if not base_url or not call_id:
return False
url = f"{base_url}/api/history/{call_id}"
payload: Dict[str, Any] = {
"status": status,
"duration_seconds": duration_seconds,
}
try:
async with aiohttp.ClientSession(timeout=_timeout()) as session:
async with session.put(url, json=payload) as resp:
resp.raise_for_status()
return True
except Exception as exc:
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
return False
return await _adapter().finalize_call_record(
call_id=call_id,
status=status,
duration_seconds=duration_seconds,
)
async def search_knowledge_context(
@@ -155,57 +75,13 @@ async def search_knowledge_context(
n_results: int = 5,
) -> List[Dict[str, Any]]:
"""Search backend knowledge base and return retrieval results."""
base_url = _backend_base_url()
if not base_url:
return []
if not kb_id or not query.strip():
return []
try:
safe_n_results = max(1, int(n_results))
except (TypeError, ValueError):
safe_n_results = 5
url = f"{base_url}/api/knowledge/search"
payload: Dict[str, Any] = {
"kb_id": kb_id,
"query": query,
"nResults": safe_n_results,
}
try:
async with aiohttp.ClientSession(timeout=_timeout()) as session:
async with session.post(url, json=payload) as resp:
if resp.status == 404:
logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
return []
resp.raise_for_status()
data = await resp.json()
if not isinstance(data, dict):
return []
results = data.get("results", [])
if not isinstance(results, list):
return []
return [r for r in results if isinstance(r, dict)]
except Exception as exc:
logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
return []
return await _adapter().search_knowledge_context(
kb_id=kb_id,
query=query,
n_results=n_results,
)
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
"""Fetch tool resource configuration from backend API."""
base_url = _backend_base_url()
if not base_url or not tool_id:
return None
url = f"{base_url}/api/tools/resources/{tool_id}"
try:
async with aiohttp.ClientSession(timeout=_timeout()) as session:
async with session.get(url) as resp:
if resp.status == 404:
return None
resp.raise_for_status()
data = await resp.json()
return data if isinstance(data, dict) else None
except Exception as exc:
logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
return None
return await _adapter().fetch_tool_resource(tool_id)


@@ -1,9 +1,360 @@
"""Configuration management using Pydantic settings."""
"""Configuration management using Pydantic settings and agent YAML profiles."""
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
try:
import yaml
except ImportError: # pragma: no cover - validated when agent YAML is used
yaml = None
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
"vad": {
"type": "vad_type",
"model_path": "vad_model_path",
"threshold": "vad_threshold",
"min_speech_duration_ms": "vad_min_speech_duration_ms",
"eou_threshold_ms": "vad_eou_threshold_ms",
},
"llm": {
"provider": "llm_provider",
"model": "llm_model",
"temperature": "llm_temperature",
"api_key": "llm_api_key",
"api_url": "llm_api_url",
},
"tts": {
"provider": "tts_provider",
"api_key": "tts_api_key",
"api_url": "tts_api_url",
"model": "tts_model",
"voice": "tts_voice",
"speed": "tts_speed",
},
"asr": {
"provider": "asr_provider",
"api_key": "asr_api_key",
"api_url": "asr_api_url",
"model": "asr_model",
"interim_interval_ms": "asr_interim_interval_ms",
"min_audio_ms": "asr_min_audio_ms",
"start_min_speech_ms": "asr_start_min_speech_ms",
"pre_speech_ms": "asr_pre_speech_ms",
"final_tail_ms": "asr_final_tail_ms",
},
"duplex": {
"enabled": "duplex_enabled",
"greeting": "duplex_greeting",
"system_prompt": "duplex_system_prompt",
},
"barge_in": {
"min_duration_ms": "barge_in_min_duration_ms",
"silence_tolerance_ms": "barge_in_silence_tolerance_ms",
},
}
_AGENT_SETTING_KEYS = {
"vad_type",
"vad_model_path",
"vad_threshold",
"vad_min_speech_duration_ms",
"vad_eou_threshold_ms",
"llm_provider",
"llm_api_key",
"llm_api_url",
"llm_model",
"llm_temperature",
"tts_provider",
"tts_api_key",
"tts_api_url",
"tts_model",
"tts_voice",
"tts_speed",
"asr_provider",
"asr_api_key",
"asr_api_url",
"asr_model",
"asr_interim_interval_ms",
"asr_min_audio_ms",
"asr_start_min_speech_ms",
"asr_pre_speech_ms",
"asr_final_tail_ms",
"duplex_enabled",
"duplex_greeting",
"duplex_system_prompt",
"barge_in_min_duration_ms",
"barge_in_silence_tolerance_ms",
"tools",
}
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
"vad_type",
"vad_model_path",
"vad_threshold",
"vad_min_speech_duration_ms",
"vad_eou_threshold_ms",
"llm_provider",
"llm_model",
"llm_temperature",
"tts_provider",
"tts_voice",
"tts_speed",
"asr_provider",
"asr_interim_interval_ms",
"asr_min_audio_ms",
"asr_start_min_speech_ms",
"asr_pre_speech_ms",
"asr_final_tail_ms",
"duplex_enabled",
"duplex_system_prompt",
"barge_in_min_duration_ms",
"barge_in_silence_tolerance_ms",
}
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
return str(overrides.get(key) or default).strip().lower()
def _is_blank(value: Any) -> bool:
return value is None or (isinstance(value, str) and not value.strip())
@dataclass(frozen=True)
class AgentConfigSelection:
"""Resolved agent config location and how it was selected."""
path: Optional[Path]
source: str
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
"""Parse only agent-related CLI flags from argv."""
config_path: Optional[str] = None
profile: Optional[str] = None
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith("--agent-config="):
config_path = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-config" and i + 1 < len(argv):
config_path = argv[i + 1].strip() or None
i += 1
elif arg.startswith("--agent-profile="):
profile = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-profile" and i + 1 < len(argv):
profile = argv[i + 1].strip() or None
i += 1
i += 1
return config_path, profile
def _agent_config_dir() -> Path:
base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR))
if not base_dir.is_absolute():
base_dir = Path.cwd() / base_dir
return base_dir.resolve()
def _resolve_agent_selection(
agent_config_path: Optional[str] = None,
agent_profile: Optional[str] = None,
argv: Optional[List[str]] = None,
) -> AgentConfigSelection:
cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:]))
path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH")
profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE")
source = "none"
candidate: Optional[Path] = None
if path_value:
source = "cli_path" if (agent_config_path or cli_path) else "env_path"
candidate = Path(path_value)
elif profile_value:
source = "cli_profile" if (agent_profile or cli_profile) else "env_profile"
candidate = _agent_config_dir() / f"{profile_value}.yaml"
else:
fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
if fallback.exists():
source = "default"
candidate = fallback
if candidate is None:
raise ValueError(
"Agent YAML config is required. Provide --agent-config/--agent-profile "
"or create config/agents/default.yaml."
)
if not candidate.is_absolute():
candidate = (Path.cwd() / candidate).resolve()
else:
candidate = candidate.resolve()
if not candidate.exists():
raise ValueError(f"Agent config file not found ({source}): {candidate}")
if not candidate.is_file():
raise ValueError(f"Agent config path is not a file: {candidate}")
return AgentConfigSelection(path=candidate, source=source)
def _resolve_env_refs(value: Any) -> Any:
"""Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively."""
if isinstance(value, dict):
return {k: _resolve_env_refs(v) for k, v in value.items()}
if isinstance(value, list):
return [_resolve_env_refs(item) for item in value]
if not isinstance(value, str) or "${" not in value:
return value
def _replace(match: re.Match[str]) -> str:
env_key = match.group(1)
default_value = match.group(2)
env_value = os.getenv(env_key)
if env_value is None:
if default_value is None:
raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}")
return default_value
return env_value
return _ENV_REF_PATTERN.sub(_replace, value)
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize YAML into flat Settings fields."""
normalized: Dict[str, Any] = {}
for key, value in raw.items():
if key == "siliconflow":
raise ValueError(
"Section 'siliconflow' is no longer supported. "
"Move provider-specific fields into agent.llm / agent.asr / agent.tts."
)
if key == "tools":
if not isinstance(value, list):
raise ValueError("Agent config key 'tools' must be a list")
normalized["tools"] = value
continue
section_map = _AGENT_SECTION_KEY_MAP.get(key)
if section_map is None:
normalized[key] = value
continue
if not isinstance(value, dict):
raise ValueError(f"Agent config section '{key}' must be a mapping")
for nested_key, nested_value in value.items():
mapped_key = section_map.get(nested_key)
if mapped_key is None:
raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'")
normalized[mapped_key] = nested_value
unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS)
if unknown_keys:
raise ValueError(
"Unknown agent config keys in YAML: "
+ ", ".join(unknown_keys)
)
return normalized
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides))
string_required = {
"vad_type",
"vad_model_path",
"llm_provider",
"llm_model",
"tts_provider",
"tts_voice",
"asr_provider",
"duplex_system_prompt",
}
for key in string_required:
if key in overrides and _is_blank(overrides.get(key)):
missing.add(key)
llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai":
if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")):
missing.add("llm_api_key")
tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")):
missing.add("tts_api_key")
if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")):
missing.add("tts_api_url")
if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")):
missing.add("tts_model")
asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")):
missing.add("asr_api_key")
if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")):
missing.add("asr_api_url")
if "asr_model" not in overrides or _is_blank(overrides.get("asr_model")):
missing.add("asr_model")
return sorted(missing)
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
if yaml is None:
raise RuntimeError(
"PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
)
with selection.path.open("r", encoding="utf-8") as file:
raw = yaml.safe_load(file) or {}
if not isinstance(raw, dict):
raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")
if "agent" in raw:
agent_value = raw["agent"]
if not isinstance(agent_value, dict):
raise ValueError("The 'agent' key in YAML must be a mapping")
raw = agent_value
resolved = _resolve_env_refs(raw)
overrides = _normalize_agent_overrides(resolved)
missing_required = _missing_required_keys(overrides)
if missing_required:
raise ValueError(
f"Missing required agent settings in YAML ({selection.path}): "
+ ", ".join(missing_required)
)
overrides["agent_config_path"] = str(selection.path)
overrides["agent_config_source"] = selection.source
return overrides
def load_settings(
agent_config_path: Optional[str] = None,
agent_profile: Optional[str] = None,
argv: Optional[List[str]] = None,
) -> "Settings":
"""Load settings from .env and optional agent YAML."""
selection = _resolve_agent_selection(
agent_config_path=agent_config_path,
agent_profile=agent_profile,
argv=argv,
)
agent_overrides = _load_agent_overrides(selection)
return Settings(**agent_overrides)
class Settings(BaseSettings):
@@ -37,30 +388,35 @@ class Settings(BaseSettings):
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
# OpenAI / LLM Configuration
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
# LLM Configuration
llm_provider: str = Field(
default="openai",
description="LLM provider (openai, openai_compatible, siliconflow)"
)
llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
# TTS Configuration
tts_provider: str = Field(
default="openai_compatible",
description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
description="TTS provider (edge, openai_compatible, siliconflow)"
)
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
tts_model: Optional[str] = Field(default=None, description="TTS model name")
tts_voice: str = Field(default="anna", description="TTS voice name")
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
# SiliconFlow Configuration
siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
# ASR Configuration
asr_provider: str = Field(
default="openai_compatible",
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
description="ASR provider (openai_compatible, buffered, siliconflow)"
)
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
asr_start_min_speech_ms: int = Field(
@@ -94,6 +450,10 @@ class Settings(BaseSettings):
description="How much silence (ms) is tolerated during potential barge-in before reset"
)
# Optional tool declarations from agent YAML.
# Supports OpenAI function schema style entries and/or shorthand string names.
tools: List[Any] = Field(default_factory=list, description="Default tool definitions for runtime")
# Logging
log_level: str = Field(default="INFO", description="Logging level")
log_format: str = Field(default="json", description="Log format (json or text)")
@@ -118,9 +478,25 @@ class Settings(BaseSettings):
ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
# Backend bridge configuration (for call/transcript persistence)
backend_mode: str = Field(
default="auto",
description="Backend integration mode: auto | http | disabled"
)
backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)")
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
history_enabled: bool = Field(default=True, description="Enable history write bridge")
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
history_queue_max_size: int = Field(default=256, description="Max buffered transcript writes per session")
history_retry_max_attempts: int = Field(default=2, description="Retry attempts for each transcript write")
history_retry_backoff_sec: float = Field(default=0.2, description="Base retry backoff for transcript writes")
history_finalize_drain_timeout_sec: float = Field(
default=1.5,
description="Max wait before finalizing history when queue is still draining"
)
# Agent YAML metadata
agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
agent_config_source: str = Field(default="none", description="How the agent YAML was selected")
@property
def chunk_size_bytes(self) -> int:
@@ -146,7 +522,7 @@ class Settings(BaseSettings):
# Global settings instance
settings = Settings()
settings = load_settings()
def get_settings() -> Settings:


@@ -20,11 +20,11 @@ except ImportError:
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
from app.config import settings
from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
from models.ws_v1 import ev
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
@@ -54,9 +54,7 @@ async def heartbeat_and_timeout_task(
break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try:
await transport.send_event({
**ev("heartbeat"),
})
await session.send_heartbeat()
last_heartbeat_at[0] = now
except Exception as e:
logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
@@ -78,6 +76,7 @@ app.add_middleware(
# Active sessions storage
active_sessions: Dict[str, Session] = {}
backend_gateway = build_backend_adapter_from_settings()
# Configure logging
logger.remove()
@@ -167,7 +166,7 @@ async def websocket_endpoint(websocket: WebSocket):
# Create transport and session
transport = SocketTransport(websocket)
session = Session(session_id, transport)
session = Session(session_id, transport, backend_gateway=backend_gateway)
active_sessions[session_id] = session
logger.info(f"WebSocket connection established: {session_id}")
@@ -246,7 +245,7 @@ async def webrtc_endpoint(websocket: WebSocket):
# Create transport and session
transport = WebRtcTransport(websocket, pc)
session = Session(session_id, transport)
session = Session(session_id, transport, backend_gateway=backend_gateway)
active_sessions[session_id] = session
logger.info(f"WebRTC connection established: {session_id}")
@@ -360,6 +359,12 @@ async def startup_event():
logger.info(f"Server: {settings.host}:{settings.port}")
logger.info(f"Sample rate: {settings.sample_rate} Hz")
logger.info(f"VAD model: {settings.vad_model_path}")
if settings.agent_config_path:
logger.info(
f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}"
)
else:
logger.info("Agent config: none (using .env/default agent values)")
@app.on_event("shutdown")


@@ -0,0 +1,50 @@
# Agent behavior configuration (safe to edit per profile)
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
# Infra/server/network settings should stay in .env.
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.5
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
# provider: openai | openai_compatible | siliconflow
provider: openai_compatible
model: deepseek-v3
temperature: 0.7
# Required; there is no fallback to .env or code defaults. You can still
# reference an env var explicitly, e.g. api_key: ${LLM_API_KEY}
api_key: your_llm_api_key
# Optional for OpenAI-compatible endpoints:
api_url: https://api.qnaigc.com/v1
tts:
# provider: edge | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech
model: FunAudioLLM/CosyVoice2-0.5B
voice: anna
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60


@@ -0,0 +1,73 @@
# Agent behavior configuration with tool declarations.
# This profile is an example only.
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.5
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
# provider: openai | openai_compatible | siliconflow
provider: openai_compatible
model: deepseek-v3
temperature: 0.7
api_key: your_llm_api_key
api_url: https://api.qnaigc.com/v1
tts:
# provider: edge | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech
model: FunAudioLLM/CosyVoice2-0.5B
voice: anna
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: You are a helpful voice assistant with tool-calling support.
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60
# Tool declarations consumed by the engine at startup.
# - String form enables built-in/default tool schema when available.
# - Object form provides OpenAI function schema + executor hint.
tools:
- current_time
- calculator
- name: weather
description: Get weather by city name.
parameters:
type: object
properties:
city:
type: string
description: City name, for example "San Francisco".
required: [city]
executor: server
- name: open_map
description: Open map app on the client device.
parameters:
type: object
properties:
query:
type: string
required: [query]
executor: client


@@ -14,7 +14,8 @@ event-driven design.
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import uuid
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -59,6 +60,12 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_ASR_DELTA_THROTTLE_MS = 300
_LLM_DELTA_THROTTLE_MS = 80
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
"name": "current_time",
@@ -79,7 +86,16 @@ class DuplexPipeline:
tts_service: Optional[BaseTTSService] = None,
asr_service: Optional[BaseASRService] = None,
system_prompt: Optional[str] = None,
greeting: Optional[str] = None
greeting: Optional[str] = None,
knowledge_searcher: Optional[
Callable[..., Awaitable[List[Dict[str, Any]]]]
] = None,
tool_resource_resolver: Optional[
Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
] = None,
server_tool_executor: Optional[
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
] = None,
):
"""
Initialize duplex pipeline.
@@ -96,6 +112,9 @@ class DuplexPipeline:
self.transport = transport
self.session_id = session_id
self.event_bus = get_event_bus()
self.track_audio_in = self.TRACK_AUDIO_IN
self.track_audio_out = self.TRACK_AUDIO_OUT
self.track_control = self.TRACK_CONTROL
# Initialize VAD
self.vad_model = SileroVAD(
@@ -117,9 +136,14 @@ class DuplexPipeline:
self.llm_service = llm_service
self.tts_service = tts_service
self.asr_service = asr_service # Will be initialized in start()
self._knowledge_searcher = knowledge_searcher
self._tool_resource_resolver = tool_resource_resolver
self._server_tool_executor = server_tool_executor
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
# Conversation manager
self.conversation = ConversationManager(
@@ -153,6 +177,7 @@ class DuplexPipeline:
self._outbound_seq = 0
self._outbound_task: Optional[asyncio.Task] = None
self._drop_outbound_audio = False
self._audio_out_frame_buffer: bytes = b""
# Interruption handling
self._interrupt_event = asyncio.Event()
@@ -181,14 +206,48 @@ class DuplexPipeline:
self._runtime_barge_in_min_duration_ms: Optional[int] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
self._runtime_tools: List[Any] = []
raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
self._runtime_tools: List[Any] = list(raw_default_tools)
self._runtime_tool_executor: Dict[str, str] = {}
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
self._next_seq: Optional[Callable[[], int]] = None
self._local_seq: int = 0
# Cross-service correlation IDs
self._turn_count: int = 0
self._response_count: int = 0
self._tts_count: int = 0
self._utterance_count: int = 0
self._current_turn_id: Optional[str] = None
self._current_utterance_id: Optional[str] = None
self._current_response_id: Optional[str] = None
self._current_tts_id: Optional[str] = None
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
self._runtime_tool_executor = self._resolved_tool_executor_map()
if self._server_tool_executor is None:
if self._tool_resource_resolver:
async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
return await execute_server_tool(
call,
tool_resource_fetcher=self._tool_resource_resolver,
)
self._server_tool_executor = _executor
else:
self._server_tool_executor = execute_server_tool
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
"""Use session-scoped monotonic sequence provider for envelope events."""
self._next_seq = provider
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
"""
Apply runtime overrides from WS session.start metadata.
@@ -276,6 +335,136 @@ class DuplexPipeline:
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
def resolved_runtime_config(self) -> Dict[str, Any]:
"""Return current effective runtime configuration without secrets."""
llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
llm_base_url = (
self._runtime_llm.get("baseUrl")
or settings.llm_api_url
or self._default_llm_base_url(llm_provider)
)
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
if not output_mode:
output_mode = "audio" if self._tts_output_enabled() else "text"
return {
"output": {"mode": output_mode},
"services": {
"llm": {
"provider": llm_provider,
"model": str(self._runtime_llm.get("model") or settings.llm_model),
"baseUrl": llm_base_url,
},
"asr": {
"provider": asr_provider,
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
},
"tts": {
"enabled": self._tts_output_enabled(),
"provider": tts_provider,
"model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
},
},
"tools": {
"allowlist": self._resolved_tool_allowlist(),
},
"tracks": {
"audio_in": self.track_audio_in,
"audio_out": self.track_audio_out,
"control": self.track_control,
},
}
def _next_event_seq(self) -> int:
if self._next_seq:
return self._next_seq()
self._local_seq += 1
return self._local_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
return "asr"
if event_type.startswith("assistant.response."):
return "llm"
if event_type.startswith("assistant.tool_"):
return "tool"
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
return "tts"
return "system"
def _new_id(self, prefix: str, counter: int) -> str:
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
def _start_turn(self) -> str:
self._turn_count += 1
self._current_turn_id = self._new_id("turn", self._turn_count)
self._current_utterance_id = None
self._current_response_id = None
self._current_tts_id = None
return self._current_turn_id
def _start_response(self) -> str:
self._response_count += 1
self._current_response_id = self._new_id("resp", self._response_count)
self._current_tts_id = None
return self._current_response_id
def _start_tts(self) -> str:
self._tts_count += 1
self._current_tts_id = self._new_id("tts", self._tts_count)
return self._current_tts_id
def _finalize_utterance(self) -> str:
if self._current_utterance_id:
return self._current_utterance_id
self._utterance_count += 1
self._current_utterance_id = self._new_id("utt", self._utterance_count)
if not self._current_turn_id:
self._start_turn()
return self._current_utterance_id
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId")
if not track_id:
if source == "asr":
track_id = self.track_audio_in
elif source in {"llm", "tts", "tool"}:
track_id = self.track_audio_out
else:
track_id = self.track_control
data = event.get("data")
if not isinstance(data, dict):
data = {}
if self._current_turn_id:
data.setdefault("turn_id", self._current_turn_id)
if self._current_utterance_id:
data.setdefault("utterance_id", self._current_utterance_id)
if self._current_response_id:
data.setdefault("response_id", self._current_response_id)
if self._current_tts_id:
data.setdefault("tts_id", self._current_tts_id)
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.session_id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
@staticmethod
def _coerce_bool(value: Any) -> Optional[bool]:
if isinstance(value, bool):
@@ -295,6 +484,18 @@ class DuplexPipeline:
normalized = str(provider or "").strip().lower()
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
@staticmethod
def _is_llm_provider_supported(provider: Any) -> bool:
normalized = str(provider or "").strip().lower()
return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"}
@staticmethod
def _default_llm_base_url(provider: Any) -> Optional[str]:
normalized = str(provider or "").strip().lower()
if normalized == "siliconflow":
return "https://api.siliconflow.cn/v1"
return None
def _tts_output_enabled(self) -> bool:
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
if enabled is not None:
@@ -370,20 +571,25 @@ class DuplexPipeline:
try:
# Connect LLM service
if not self.llm_service:
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
llm_base_url = (
self._runtime_llm.get("baseUrl")
or settings.llm_api_url
or self._default_llm_base_url(llm_provider)
)
llm_model = self._runtime_llm.get("model") or settings.llm_model
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
if llm_provider == "openai" and llm_api_key:
if self._is_llm_provider_supported(llm_provider) and llm_api_key:
self.llm_service = OpenAILLMService(
api_key=llm_api_key,
base_url=llm_base_url,
model=llm_model,
knowledge_config=self._resolved_knowledge_config(),
knowledge_searcher=self._knowledge_searcher,
)
else:
logger.warning("No OpenAI API key - using mock LLM")
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
self.llm_service = MockLLMService()
if hasattr(self.llm_service, "set_knowledge_config"):
@@ -399,20 +605,22 @@ class DuplexPipeline:
if tts_output_enabled:
if not self.tts_service:
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
tts_model = self._runtime_tts.get("model") or settings.tts_model
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
self.tts_service = OpenAICompatibleTTSService(
api_key=tts_api_key,
api_url=tts_api_url,
voice=tts_voice,
model=tts_model,
model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
sample_rate=settings.sample_rate,
speed=tts_speed
)
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
else:
self.tts_service = EdgeTTSService(
voice=tts_voice,
@@ -435,21 +643,23 @@ class DuplexPipeline:
# Connect ASR service
if not self.asr_service:
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
self.asr_service = OpenAICompatibleASRService(
api_key=asr_api_key,
model=asr_model,
api_url=asr_api_url,
model=asr_model or "FunAudioLLM/SenseVoiceSmall",
sample_rate=settings.sample_rate,
interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback
)
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
else:
self.asr_service = BufferedASRService(
sample_rate=settings.sample_rate
@@ -472,11 +682,13 @@ class DuplexPipeline:
greeting_to_speak = generated_greeting
self.conversation.greeting = generated_greeting
if greeting_to_speak:
self._start_turn()
self._start_response()
await self._send_event(
ev(
"assistant.response.final",
text=greeting_to_speak,
trackId=self.session_id,
trackId=self.track_audio_out,
),
priority=20,
)
@@ -494,10 +706,58 @@ class DuplexPipeline:
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
await self._enqueue_outbound("event", event, priority)
await self._enqueue_outbound("event", self._envelope_event(event), priority)
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
await self._enqueue_outbound("audio", pcm_bytes, priority)
if not pcm_bytes:
return
self._audio_out_frame_buffer += pcm_bytes
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
await self._enqueue_outbound("audio", frame, priority)
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
if not self._audio_out_frame_buffer:
return
tail = self._audio_out_frame_buffer
self._audio_out_frame_buffer = b""
if len(tail) < self._PCM_FRAME_BYTES:
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
await self._enqueue_outbound("audio", tail, priority)
async def _emit_transcript_delta(self, text: str) -> None:
await self._send_event(
{
**ev(
"transcript.delta",
trackId=self.track_audio_in,
text=text,
)
},
priority=30,
)
async def _emit_llm_delta(self, text: str) -> None:
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.track_audio_out,
text=text,
)
},
priority=20,
)
async def _flush_pending_llm_delta(self) -> None:
if not self._pending_llm_delta:
return
chunk = self._pending_llm_delta
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
await self._emit_llm_delta(chunk)
async def _outbound_loop(self) -> None:
"""Single sender loop that enforces priority for interrupt events."""
@@ -546,13 +806,13 @@ class DuplexPipeline:
# Emit VAD event
await self.event_bus.publish(event_type, {
"trackId": self.session_id,
"trackId": self.track_audio_in,
"probability": probability
})
await self._send_event(
ev(
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
trackId=self.session_id,
trackId=self.track_audio_in,
probability=probability,
),
priority=30,
@@ -661,6 +921,9 @@ class DuplexPipeline:
# Cancel any current speaking
await self._stop_current_speech()
self._start_turn()
self._finalize_utterance()
# Start new turn
await self.conversation.end_user_turn(text)
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
@@ -683,24 +946,45 @@ class DuplexPipeline:
if text == self._last_sent_transcript and not is_final:
return
now_ms = time.monotonic() * 1000.0
self._last_sent_transcript = text
# Send transcript event to client
await self._send_event({
if is_final:
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
{
**ev(
"transcript.final" if is_final else "transcript.delta",
trackId=self.session_id,
"transcript.final",
trackId=self.track_audio_in,
text=text,
)
}, priority=30)
},
priority=30,
)
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
)
if should_emit and self._pending_transcript_delta:
delta = self._pending_transcript_delta
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = now_ms
await self._emit_transcript_delta(delta)
if not is_final:
logger.info(f"[ASR] ASR interim: {text[:100]}")
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
"""Handle user starting to speak."""
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
self._start_turn()
self._finalize_utterance()
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
@@ -779,6 +1063,7 @@ class DuplexPipeline:
return
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
self._finalize_utterance()
# For ASR backends that already emitted final via callback,
# avoid duplicating transcript.final on EOU.
@@ -786,7 +1071,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"transcript.final",
trackId=self.session_id,
trackId=self.track_audio_in,
text=user_text,
)
}, priority=25)
@@ -794,6 +1079,8 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False
self._pending_speech_audio = b""
@@ -881,6 +1168,23 @@ class DuplexPipeline:
result[name] = executor
return result
def _resolved_tool_allowlist(self) -> List[str]:
names: set[str] = set()
for item in self._runtime_tools:
if isinstance(item, str):
name = item.strip()
if name:
names.add(name)
continue
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
names.add(str(fn.get("name")).strip())
elif item.get("name"):
names.add(str(item.get("name")).strip())
return sorted([name for name in names if name])
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
fn = tool_call.get("function")
if isinstance(fn, dict):
@@ -894,6 +1198,44 @@ class DuplexPipeline:
# Default to server execution unless explicitly marked as client.
return "server"
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
fn = tool_call.get("function")
if not isinstance(fn, dict):
return {}
raw = fn.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str) and raw.strip():
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {"raw": raw}
except Exception:
return {"raw": raw}
return {}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
status_code = int(status.get("code") or 0) if status else 0
status_message = str(status.get("message") or "") if status else ""
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
tool_name = str(result.get("name") or "unknown_tool")
ok = bool(200 <= status_code < 300)
retryable = status_code >= 500 or status_code in {429, 408}
error: Optional[Dict[str, Any]] = None
if not ok:
error = {
"code": status_code or 500,
"message": status_message or "tool_execution_failed",
"retryable": retryable,
}
return {
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"ok": ok,
"error": error,
"status": {"code": status_code, "message": status_message},
}
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
tool_name = str(result.get("name") or "unknown_tool")
call_id = str(result.get("tool_call_id") or result.get("id") or "")
@@ -904,12 +1246,17 @@ class DuplexPipeline:
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
f"status={status_code} {status_message}".strip()
)
normalized = self._normalize_tool_result(result)
await self._send_event(
{
**ev(
"assistant.tool_result",
trackId=self.session_id,
trackId=self.track_audio_out,
source=source,
tool_call_id=normalized["tool_call_id"],
tool_name=normalized["tool_name"],
ok=normalized["ok"],
error=normalized["error"],
result=result,
)
},
@@ -927,6 +1274,9 @@ class DuplexPipeline:
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
if not call_id:
continue
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
continue
if call_id in self._completed_tool_call_ids:
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
continue
@@ -972,6 +1322,7 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
@@ -998,6 +1349,11 @@ class DuplexPipeline:
user_text: User's transcribed text
"""
try:
if not self._current_turn_id:
self._start_turn()
if not self._current_utterance_id:
self._finalize_utterance()
self._start_response()
# Start latency tracking
self._turn_start_time = time.time()
self._first_audio_sent = False
@@ -1012,6 +1368,8 @@ class DuplexPipeline:
self._drop_outbound_audio = False
first_audio_sent = False
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
for _ in range(max_rounds):
if self._interrupt_event.is_set():
break
@@ -1028,6 +1386,7 @@ class DuplexPipeline:
event = self._normalize_stream_event(raw_event)
if event.type == "tool_call":
await self._flush_pending_llm_delta()
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
if not tool_call:
continue
@@ -1045,11 +1404,19 @@ class DuplexPipeline:
f"executor={executor} args={args_preview}"
)
tool_calls.append(enriched_tool_call)
tool_arguments = self._tool_arguments(enriched_tool_call)
if executor == "client" and call_id:
self._pending_client_tool_call_ids.add(call_id)
await self._send_event(
{
**ev(
"assistant.tool_call",
trackId=self.session_id,
trackId=self.track_audio_out,
tool_call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
tool_call=enriched_tool_call,
)
},
@@ -1071,19 +1438,13 @@ class DuplexPipeline:
round_response += text_chunk
sentence_buffer += text_chunk
await self.conversation.update_assistant_text(text_chunk)
await self._send_event(
{
**ev(
"assistant.response.delta",
trackId=self.session_id,
text=text_chunk,
)
},
# Keep delta/final on the same event priority so FIFO seq
# preserves stream order (avoid late-delta after final).
priority=20,
)
self._pending_llm_delta += text_chunk
now_ms = time.monotonic() * 1000.0
if (
self._last_llm_delta_emit_ms <= 0.0
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
):
await self._flush_pending_llm_delta()
while True:
split_result = extract_tts_sentence(
@@ -1112,11 +1473,12 @@ class DuplexPipeline:
if self._tts_output_enabled() and not self._interrupt_event.is_set():
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1130,6 +1492,7 @@ class DuplexPipeline:
)
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
await self._flush_pending_llm_delta()
if (
self._tts_output_enabled()
and remaining_text
@@ -1137,11 +1500,12 @@ class DuplexPipeline:
and not self._interrupt_event.is_set()
):
if not first_audio_sent:
self._start_tts()
await self._send_event(
{
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
},
priority=10,
@@ -1172,7 +1536,7 @@ class DuplexPipeline:
try:
result = await asyncio.wait_for(
execute_server_tool(call),
self._server_tool_executor(call),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
@@ -1204,11 +1568,12 @@ class DuplexPipeline:
]
if full_response and not self._interrupt_event.is_set():
await self._flush_pending_llm_delta()
await self._send_event(
{
**ev(
"assistant.response.final",
trackId=self.session_id,
trackId=self.track_audio_out,
text=full_response,
)
},
@@ -1217,10 +1582,11 @@ class DuplexPipeline:
# Send track end
if first_audio_sent:
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1241,6 +1607,8 @@ class DuplexPipeline:
self._barge_in_speech_start_time = None
self._barge_in_speech_frames = 0
self._barge_in_silence_frames = 0
self._current_response_id = None
self._current_tts_id = None
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
"""
@@ -1277,7 +1645,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1354,10 +1722,11 @@ class DuplexPipeline:
first_audio_sent = False
# Send track start event
self._start_tts()
await self._send_event({
**ev(
"output.audio.start",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1379,7 +1748,7 @@ class DuplexPipeline:
await self._send_event({
**ev(
"metrics.ttfb",
trackId=self.session_id,
trackId=self.track_audio_out,
latencyMs=round(ttfb_ms),
)
}, priority=25)
@@ -1391,10 +1760,11 @@ class DuplexPipeline:
await asyncio.sleep(0.01)
# Send track end event
await self._flush_audio_out_frames(priority=50)
await self._send_event({
**ev(
"output.audio.end",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=10)
@@ -1422,13 +1792,14 @@ class DuplexPipeline:
self._interrupt_event.set()
self._is_bot_speaking = False
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
# Send interrupt event to client IMMEDIATELY
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
await self._send_event({
**ev(
"response.interrupted",
trackId=self.session_id,
trackId=self.track_audio_out,
)
}, priority=0)
@@ -1455,6 +1826,7 @@ class DuplexPipeline:
async def _stop_current_speech(self) -> None:
"""Stop any current speech task."""
self._drop_outbound_audio = True
self._audio_out_frame_buffer = b""
if self._current_turn_task and not self._current_turn_task.done():
self._interrupt_event.set()
self._current_turn_task.cancel()

View File

@@ -0,0 +1,244 @@
"""Async history bridge for non-blocking transcript persistence."""
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass
from typing import Any, Optional
from loguru import logger
@dataclass
class _HistoryTranscriptJob:
call_id: str
turn_index: int
speaker: str
content: str
start_ms: int
end_ms: int
duration_ms: int
class SessionHistoryBridge:
"""Session-scoped buffered history writer with background retries."""
_STOP_SENTINEL = object()
def __init__(
self,
*,
history_writer: Any,
enabled: bool,
queue_max_size: int,
retry_max_attempts: int,
retry_backoff_sec: float,
finalize_drain_timeout_sec: float,
):
self._history_writer = history_writer
self._enabled = bool(enabled and history_writer is not None)
self._queue_max_size = max(1, int(queue_max_size))
self._retry_max_attempts = max(0, int(retry_max_attempts))
self._retry_backoff_sec = max(0.0, float(retry_backoff_sec))
self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec))
self._call_id: Optional[str] = None
self._turn_index: int = 0
self._started_mono: Optional[float] = None
self._finalized: bool = False
self._worker_task: Optional[asyncio.Task] = None
self._finalize_lock = asyncio.Lock()
self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size)
@property
def enabled(self) -> bool:
return self._enabled
@property
def call_id(self) -> Optional[str]:
return self._call_id
async def start_call(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str,
) -> Optional[str]:
"""Create remote call record and start background worker."""
if not self._enabled or self._call_id:
return self._call_id
call_id = await self._history_writer.create_call_record(
user_id=user_id,
assistant_id=assistant_id,
source=source,
)
if not call_id:
return None
self._call_id = str(call_id)
self._turn_index = 0
self._finalized = False
self._started_mono = time.monotonic()
self._ensure_worker()
return self._call_id
def elapsed_ms(self) -> int:
if self._started_mono is None:
return 0
return max(0, int((time.monotonic() - self._started_mono) * 1000))
def enqueue_turn(self, *, role: str, text: str) -> bool:
"""Queue one transcript write without blocking the caller."""
if not self._enabled or not self._call_id or self._finalized:
return False
content = str(text or "").strip()
if not content:
return False
speaker = "human" if str(role or "").strip().lower() == "user" else "ai"
end_ms = self.elapsed_ms()
estimated_duration_ms = max(300, min(12000, len(content) * 80))
start_ms = max(0, end_ms - estimated_duration_ms)
job = _HistoryTranscriptJob(
call_id=self._call_id,
turn_index=self._turn_index,
speaker=speaker,
content=content,
start_ms=start_ms,
end_ms=end_ms,
duration_ms=max(1, end_ms - start_ms),
)
self._turn_index += 1
self._ensure_worker()
try:
self._queue.put_nowait(job)
return True
except asyncio.QueueFull:
logger.warning(
"History queue full; dropping transcript call_id={} turn={}",
self._call_id,
job.turn_index,
)
return False
async def finalize(self, *, status: str) -> bool:
"""Finalize history record once; waits briefly for queue drain."""
if not self._enabled or not self._call_id:
return False
async with self._finalize_lock:
if self._finalized:
return True
await self._drain_queue()
ok = await self._history_writer.finalize_call_record(
call_id=self._call_id,
status=status,
duration_seconds=self.duration_seconds(),
)
if ok:
self._finalized = True
await self._stop_worker()
return ok
async def shutdown(self) -> None:
"""Stop worker task and release queue resources."""
await self._stop_worker()
def duration_seconds(self) -> int:
if self._started_mono is None:
return 0
return max(0, int(time.monotonic() - self._started_mono))
def _ensure_worker(self) -> None:
if self._worker_task and not self._worker_task.done():
return
self._worker_task = asyncio.create_task(self._worker_loop())
async def _drain_queue(self) -> None:
if self._finalize_drain_timeout_sec <= 0:
return
try:
await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec)
except asyncio.TimeoutError:
logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec)
async def _stop_worker(self) -> None:
task = self._worker_task
if not task:
return
if task.done():
self._worker_task = None
return
sent = False
try:
self._queue.put_nowait(self._STOP_SENTINEL)
sent = True
except asyncio.QueueFull:
pass
if not sent:
try:
await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5)
except asyncio.TimeoutError:
task.cancel()
try:
await asyncio.wait_for(task, timeout=1.5)
except asyncio.TimeoutError:
task.cancel()
try:
await task
except Exception:
pass
except asyncio.CancelledError:
pass
finally:
self._worker_task = None
async def _worker_loop(self) -> None:
while True:
item = await self._queue.get()
try:
if item is self._STOP_SENTINEL:
return
assert isinstance(item, _HistoryTranscriptJob)
await self._write_with_retry(item)
except Exception as exc:
logger.warning("History worker write failed unexpectedly: {}", exc)
finally:
self._queue.task_done()
async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool:
for attempt in range(self._retry_max_attempts + 1):
ok = await self._history_writer.add_transcript(
call_id=job.call_id,
turn_index=job.turn_index,
speaker=job.speaker,
content=job.content,
start_ms=job.start_ms,
end_ms=job.end_ms,
duration_ms=job.duration_ms,
)
if ok:
return True
if attempt >= self._retry_max_attempts:
logger.warning(
"History write dropped after retries call_id={} turn={}",
job.call_id,
job.turn_index,
)
return False
if self._retry_backoff_sec > 0:
await asyncio.sleep(self._retry_backoff_sec * (2**attempt))
return False

View File

@@ -0,0 +1,17 @@
"""Port interfaces for engine-side integration boundaries."""
from core.ports.backend import (
AssistantConfigProvider,
BackendGateway,
HistoryWriter,
KnowledgeSearcher,
ToolResourceResolver,
)
__all__ = [
"AssistantConfigProvider",
"BackendGateway",
"HistoryWriter",
"KnowledgeSearcher",
"ToolResourceResolver",
]

View File

@@ -0,0 +1,84 @@
"""Backend integration ports.
These interfaces define the boundary between engine runtime logic and
backend-side capabilities (config lookup, history persistence, retrieval,
and tool resource discovery).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Protocol
class AssistantConfigProvider(Protocol):
"""Port for loading trusted assistant runtime configuration."""
async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
"""Fetch assistant configuration payload."""
class HistoryWriter(Protocol):
"""Port for persisting call and transcript history."""
async def create_call_record(
self,
*,
user_id: int,
assistant_id: Optional[str],
source: str = "debug",
) -> Optional[str]:
"""Create a call record and return backend call ID."""
async def add_transcript(
self,
*,
call_id: str,
turn_index: int,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: Optional[float] = None,
duration_ms: Optional[int] = None,
) -> bool:
"""Append one transcript turn segment."""
async def finalize_call_record(
self,
*,
call_id: str,
status: str,
duration_seconds: int,
) -> bool:
"""Finalize a call record."""
class KnowledgeSearcher(Protocol):
"""Port for RAG / knowledge retrieval operations."""
async def search_knowledge_context(
self,
*,
kb_id: str,
query: str,
n_results: int = 5,
) -> List[Dict[str, Any]]:
"""Search a knowledge source and return ranked snippets."""
class ToolResourceResolver(Protocol):
"""Port for resolving tool metadata/configuration."""
async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
"""Fetch tool resource configuration."""
class BackendGateway(
AssistantConfigProvider,
HistoryWriter,
KnowledgeSearcher,
ToolResourceResolver,
Protocol,
):
"""Composite backend gateway interface used by engine services."""

View File

@@ -1,22 +1,19 @@
"""Session management for active calls."""
import asyncio
import uuid
import hashlib
import json
import time
import re
import time
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
)
from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.history_bridge import SessionHistoryBridge
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
@@ -49,7 +46,39 @@ class Session:
Uses full duplex voice conversation pipeline.
"""
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
_CLIENT_METADATA_OVERRIDES = {
"firstTurnMode",
"greeting",
"generatedOpenerEnabled",
"systemPrompt",
"output",
"bargeIn",
"knowledge",
"knowledgeBaseId",
"history",
"userId",
"assistantId",
"source",
}
_CLIENT_METADATA_ID_KEYS = {
"appId",
"app_id",
"channel",
"configVersionId",
"config_version_id",
}
def __init__(
self,
session_id: str,
transport: BaseTransport,
use_duplex: bool = None,
backend_gateway: Optional[Any] = None,
):
"""
Initialize session.
@@ -61,12 +90,23 @@ class Session:
self.id = session_id
self.transport = transport
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
self._history_bridge = SessionHistoryBridge(
history_writer=self._backend_gateway,
enabled=settings.history_enabled,
queue_max_size=settings.history_queue_max_size,
retry_max_attempts=settings.history_retry_max_attempts,
retry_backoff_sec=settings.history_retry_backoff_sec,
finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
)
self.pipeline = DuplexPipeline(
transport=transport,
session_id=session_id,
system_prompt=settings.duplex_system_prompt,
greeting=settings.duplex_greeting
greeting=settings.duplex_greeting,
knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
)
# Session state
@@ -78,17 +118,15 @@ class Session:
self.authenticated: bool = False
# Track IDs
self.current_track_id: Optional[str] = str(uuid.uuid4())
self._history_call_id: Optional[str] = None
self._history_turn_index: int = 0
self._history_call_started_mono: Optional[float] = None
self._history_finalized: bool = False
self.current_track_id: str = self.TRACK_CONTROL
self._event_seq: int = 0
self._cleanup_lock = asyncio.Lock()
self._cleaned_up = False
self.workflow_runner: Optional[WorkflowRunner] = None
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.set_event_sequence_provider(self._next_event_seq)
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
@@ -129,13 +167,47 @@ class Session:
"client",
"Audio received before session.start",
"protocol.order",
stage="protocol",
retryable=False,
)
return
try:
await self.pipeline.process_audio(audio_bytes)
if not audio_bytes:
return
if len(audio_bytes) % 2 != 0:
await self._send_error(
"client",
"Invalid PCM payload: odd number of bytes",
"audio.invalid_pcm",
stage="audio",
retryable=False,
)
return
frame_bytes = self.AUDIO_FRAME_BYTES
if len(audio_bytes) % frame_bytes != 0:
await self._send_error(
"client",
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
"audio.frame_size_mismatch",
stage="audio",
retryable=False,
)
return
for i in range(0, len(audio_bytes), frame_bytes):
frame = audio_bytes[i : i + frame_bytes]
await self.pipeline.process_audio(frame)
except Exception as e:
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
await self._send_error(
"server",
f"Audio processing failed: {e}",
"audio.processing_failed",
stage="audio",
retryable=True,
)
async def _handle_v1_message(self, message: Any) -> None:
"""Route validated WS v1 message to handlers."""
@@ -176,7 +248,7 @@ class Session:
else:
await self.pipeline.interrupt()
elif isinstance(message, ToolCallResultsMessage):
await self.pipeline.handle_tool_call_results(message.results)
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
elif isinstance(message, SessionStopMessage):
await self._handle_session_stop(message.reason)
else:
@@ -198,9 +270,9 @@ class Session:
self.ws_state = WsSessionState.STOPPED
return
auth_payload = message.auth or {}
api_key = auth_payload.get("apiKey")
jwt = auth_payload.get("jwt")
auth_payload = message.auth
api_key = auth_payload.apiKey if auth_payload else None
jwt = auth_payload.jwt if auth_payload else None
if settings.ws_api_key:
if api_key != settings.ws_api_key:
@@ -217,10 +289,9 @@ class Session:
self.authenticated = True
self.protocol_version = message.version
self.ws_state = WsSessionState.WAIT_START
await self.transport.send_event(
await self._send_event(
ev(
"hello.ack",
sessionId=self.id,
version=self.protocol_version,
)
)
@@ -231,8 +302,12 @@ class Session:
await self._send_error("client", "Duplicate session.start", "protocol.order")
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
raw_metadata = message.metadata or {}
workflow_runtime = self._bootstrap_workflow(raw_metadata)
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
client_runtime = self._sanitize_client_metadata(raw_metadata)
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
metadata = self._merge_runtime_metadata(metadata, client_runtime)
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -248,28 +323,37 @@ class Session:
self.state = "accepted"
self.ws_state = WsSessionState.ACTIVE
await self.transport.send_event(
await self._send_event(
ev(
"session.started",
sessionId=self.id,
trackId=self.current_track_id,
audio=message.audio or {},
tracks={
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
audio=message.audio.model_dump() if message.audio else {},
)
)
await self._send_event(
ev(
"config.resolved",
trackId=self.TRACK_CONTROL,
config=self._build_config_resolved(metadata),
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
await self._send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
@@ -285,17 +369,24 @@ class Session:
stop_reason = reason or "client_requested"
self.state = "hungup"
self.ws_state = WsSessionState.STOPPED
await self.transport.send_event(
await self._send_event(
ev(
"session.stopped",
sessionId=self.id,
reason=stop_reason,
)
)
await self._finalize_history(status="connected")
await self.transport.close()
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
async def _send_error(
self,
sender: str,
error_message: str,
code: str,
stage: Optional[str] = None,
retryable: Optional[bool] = None,
track_id: Optional[str] = None,
) -> None:
"""
Send error event to client.
@@ -304,13 +395,26 @@ class Session:
error_message: Error message
code: Machine-readable error code
"""
await self.transport.send_event(
resolved_stage = stage or self._infer_error_stage(code)
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
await self._send_event(
ev(
"error",
sender=sender,
code=code,
message=error_message,
trackId=self.current_track_id,
stage=resolved_stage,
retryable=resolved_retryable,
trackId=resolved_track_id,
data={
"error": {
"stage": resolved_stage,
"code": code,
"message": error_message,
"retryable": resolved_retryable,
}
},
)
)
@@ -329,11 +433,12 @@ class Session:
logger.info(f"Session {self.id} cleaning up")
await self._finalize_history(status="connected")
await self.pipeline.cleanup()
await self._history_bridge.shutdown()
await self.transport.close()
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
"""Initialize backend history call record for this session."""
if self._history_call_id:
if self._history_bridge.call_id:
return
history_meta: Dict[str, Any] = {}
@@ -349,7 +454,7 @@ class Session:
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
source = str(history_meta.get("source", metadata.get("source", "debug")))
call_id = await create_history_call_record(
call_id = await self._history_bridge.start_call(
user_id=user_id,
assistant_id=str(assistant_id) if assistant_id else None,
source=source,
@@ -357,10 +462,6 @@ class Session:
if not call_id:
return
self._history_call_id = call_id
self._history_call_started_mono = time.monotonic()
self._history_turn_index = 0
self._history_finalized = False
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
@@ -372,48 +473,11 @@ class Session:
elif role == "assistant":
await self._maybe_advance_workflow(turn.text.strip())
if not self._history_call_id:
return
if not turn.text or not turn.text.strip():
return
role = (turn.role or "").lower()
speaker = "human" if role == "user" else "ai"
end_ms = 0
if self._history_call_started_mono is not None:
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
start_ms = max(0, end_ms - estimated_duration_ms)
turn_index = self._history_turn_index
await add_history_transcript(
call_id=self._history_call_id,
turn_index=turn_index,
speaker=speaker,
content=turn.text.strip(),
start_ms=start_ms,
end_ms=end_ms,
duration_ms=max(1, end_ms - start_ms),
)
self._history_turn_index += 1
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
async def _finalize_history(self, status: str) -> None:
"""Finalize history call record once."""
if not self._history_call_id or self._history_finalized:
return
duration_seconds = 0
if self._history_call_started_mono is not None:
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
ok = await finalize_history_call_record(
call_id=self._history_call_id,
status=status,
duration_seconds=duration_seconds,
)
if ok:
self._history_finalized = True
await self._history_bridge.finalize(status=status)
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Parse workflow payload and return initial runtime overrides."""
@@ -483,10 +547,9 @@ class Session:
node = transition.node
edge = transition.edge
await self.transport.send_event(
await self._send_event(
ev(
"workflow.edge.taken",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
edgeId=edge.id,
fromNodeId=edge.from_node_id,
@@ -494,10 +557,9 @@ class Session:
reason=reason,
)
)
await self.transport.send_event(
await self._send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
nodeName=node.name,
@@ -510,10 +572,9 @@ class Session:
self.pipeline.apply_runtime_overrides(node_runtime)
if node.node_type == "tool":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.tool.requested",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
tool=node.tool or {},
@@ -522,10 +583,9 @@ class Session:
return
if node.node_type == "human_transfer":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.human_transfer",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
@@ -534,16 +594,77 @@ class Session:
return
if node.node_type == "end":
await self.transport.send_event(
await self._send_event(
ev(
"workflow.ended",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=node.id,
)
)
await self._handle_session_stop("workflow_end")
def _next_event_seq(self) -> int:
self._event_seq += 1
return self._event_seq
def _event_source(self, event_type: str) -> str:
if event_type.startswith("workflow."):
return "system"
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
return "system"
if event_type == "error":
return "system"
return "system"
def _infer_error_stage(self, code: str) -> str:
normalized = str(code or "").strip().lower()
if normalized.startswith("audio."):
return "audio"
if normalized.startswith("tool."):
return "tool"
if normalized.startswith("asr."):
return "asr"
if normalized.startswith("llm."):
return "llm"
if normalized.startswith("tts."):
return "tts"
return "protocol"
def _error_track_id(self, stage: str, code: str) -> str:
if stage in {"audio", "asr"}:
return self.TRACK_AUDIO_IN
if stage in {"llm", "tts", "tool"}:
return self.TRACK_AUDIO_OUT
if str(code or "").strip().lower().startswith("auth."):
return self.TRACK_CONTROL
return self.TRACK_CONTROL
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
event_type = str(event.get("type") or "")
source = str(event.get("source") or self._event_source(event_type))
track_id = event.get("trackId") or self.TRACK_CONTROL
data = event.get("data")
if not isinstance(data, dict):
data = {}
for k, v in event.items():
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
continue
data.setdefault(k, v)
event["sessionId"] = self.id
event["seq"] = self._next_event_seq()
event["source"] = source
event["trackId"] = track_id
event["data"] = data
return event
async def _send_event(self, event: Dict[str, Any]) -> None:
await self.transport.send_event(self._envelope_event(event))
async def send_heartbeat(self) -> None:
await self._send_event(ev("heartbeat", trackId=self.TRACK_CONTROL))
async def _workflow_llm_route(
self,
node: WorkflowNodeDef,
@@ -629,6 +750,137 @@ class Session:
merged[key] = value
return merged
async def _load_server_runtime_metadata(
self,
client_metadata: Dict[str, Any],
workflow_runtime: Dict[str, Any],
) -> Dict[str, Any]:
"""Load trusted runtime metadata from backend assistant config."""
assistant_id = (
workflow_runtime.get("assistantId")
or client_metadata.get("assistantId")
or client_metadata.get("appId")
or client_metadata.get("app_id")
)
if assistant_id is None:
return {}
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
if not callable(provider):
return {}
payload = await provider(str(assistant_id).strip())
if not isinstance(payload, dict):
return {}
assistant_cfg: Dict[str, Any] = {}
session_start_cfg = payload.get("sessionStartMetadata")
if isinstance(session_start_cfg, dict):
assistant_cfg.update(session_start_cfg)
if isinstance(payload.get("assistant"), dict):
assistant_cfg.update(payload.get("assistant"))
elif not assistant_cfg:
assistant_cfg = payload
if not isinstance(assistant_cfg, dict):
return {}
runtime: Dict[str, Any] = {}
passthrough_keys = {
"firstTurnMode",
"generatedOpenerEnabled",
"output",
"bargeIn",
"knowledgeBaseId",
"knowledge",
"history",
"userId",
"source",
"tools",
"services",
"configVersionId",
"config_version_id",
}
for key in passthrough_keys:
if key in assistant_cfg:
runtime[key] = assistant_cfg[key]
if assistant_cfg.get("systemPrompt") is not None:
runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
elif assistant_cfg.get("prompt") is not None:
runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
if assistant_cfg.get("greeting") is not None:
runtime["greeting"] = assistant_cfg.get("greeting")
elif assistant_cfg.get("opener") is not None:
runtime["greeting"] = assistant_cfg.get("opener")
resolved_assistant_id = (
assistant_cfg.get("assistantId")
or payload.get("assistantId")
or assistant_id
)
runtime["assistantId"] = str(resolved_assistant_id)
if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
runtime["configVersionId"] = payload.get("configVersionId")
if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
runtime["configVersionId"] = payload.get("config_version_id")
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
runtime["configVersionId"] = runtime.get("config_version_id")
return runtime
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize untrusted metadata sources.
This keeps only a small override whitelist and stable config ID fields.
"""
if not isinstance(metadata, dict):
return {}
sanitized: Dict[str, Any] = {}
for key in self._CLIENT_METADATA_ID_KEYS:
if key in metadata:
sanitized[key] = metadata[key]
for key in self._CLIENT_METADATA_OVERRIDES:
if key in metadata:
sanitized[key] = metadata[key]
return sanitized
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Apply client metadata whitelist and remove forbidden secrets."""
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
if isinstance(metadata.get("services"), dict):
logger.warning(
"Session {} provided metadata.services from client; client-side service config is ignored",
self.id,
)
return sanitized
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Build public resolved config payload (secrets removed)."""
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
runtime = self.pipeline.resolved_runtime_config()
return {
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
"channel": metadata.get("channel"),
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
"prompt": {"sha256": prompt_hash},
"output": runtime.get("output", {}),
"services": runtime.get("services", {}),
"tools": runtime.get("tools", {}),
"tracks": {
"audio_in": self.TRACK_AUDIO_IN,
"audio_out": self.TRACK_AUDIO_OUT,
"control": self.TRACK_CONTROL,
},
}
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:

View File

@@ -4,11 +4,13 @@ import asyncio
import ast
import operator
from datetime import datetime
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict, Optional
import aiohttp
from app.backend_client import fetch_tool_resource
from app.backend_adapters import build_backend_adapter_from_settings
ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
_BIN_OPS = {
ast.Add: operator.add,
@@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
return {}
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
"""Default tool resource resolver via backend adapter."""
adapter = build_backend_adapter_from_settings()
return await adapter.fetch_tool_resource(tool_id)
async def execute_server_tool(
tool_call: Dict[str, Any],
tool_resource_fetcher: Optional[ToolResourceFetcher] = None,
) -> Dict[str, Any]:
"""Execute a server-side tool and return normalized result payload."""
call_id = str(tool_call.get("id") or "").strip()
tool_name = _extract_tool_name(tool_call)
args = _extract_tool_args(tool_call)
resource_fetcher = tool_resource_fetcher or fetch_tool_resource
if tool_name == "calculator":
expression = str(args.get("expression") or "").strip()
@@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
}
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
resource = await fetch_tool_resource(tool_name)
resource = await resource_fetcher(tool_name)
if resource and str(resource.get("category") or "") == "query":
method = str(resource.get("http_method") or "GET").strip().upper()
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:

View File

@@ -0,0 +1,47 @@
# Backend Integration and History Bridge
This engine uses adapter-based backend integration so core runtime logic can run
with or without an external backend service.
## Runtime Modes
Configure with environment variables:
- `BACKEND_MODE=auto|http|disabled`
- `BACKEND_URL`
- `BACKEND_TIMEOUT_SEC`
- `HISTORY_ENABLED=true|false`
Mode behavior:
- `auto`: use HTTP backend adapter only when `BACKEND_URL` is set.
- `http`: force HTTP backend adapter (falls back to null adapter when URL is missing).
- `disabled`: force null adapter and run engine-only.
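The selection can be sketched as follows (illustrative only; the shipped factory is `app/backend_adapters.build_backend_adapter_from_settings` and may differ in detail):
```python
import os

def resolve_backend_mode(env=os.environ) -> str:
    """Return which adapter family would be built: "http" or "null"."""
    mode = env.get("BACKEND_MODE", "auto").strip().lower()
    url = env.get("BACKEND_URL", "").strip()
    if mode == "disabled":
        return "null"  # engine-only
    if mode == "http":
        return "http" if url else "null"  # documented fallback when URL is missing
    return "http" if url else "null"      # auto: HTTP only when BACKEND_URL is set
```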
## Architecture
- Ports: `core/ports/backend.py`
- Adapters: `app/backend_adapters.py`
- Compatibility wrappers: `app/backend_client.py`
`Session` and `DuplexPipeline` receive backend capabilities via injected adapter
methods instead of hard-coding backend client imports.
## Async History Writes
Session history persistence is handled by `core/history_bridge.py`.
Design:
- transcript writes are queued with `put_nowait` (non-blocking turn path)
- background worker drains queue
- failed writes retry with exponential backoff
- finalize waits briefly for queue drain before sending call finalize
- finalize is idempotent
Related settings:
- `HISTORY_QUEUE_MAX_SIZE`
- `HISTORY_RETRY_MAX_ATTEMPTS`
- `HISTORY_RETRY_BACKOFF_SEC`
- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC`
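A minimal usage sketch of the bridge, assuming an `adapter` object implementing the `HistoryWriter` port (the numeric values stand in for the settings above):
```python
import asyncio
from core.history_bridge import SessionHistoryBridge

async def demo(adapter) -> None:
    bridge = SessionHistoryBridge(
        history_writer=adapter,
        enabled=True,
        queue_max_size=256,              # HISTORY_QUEUE_MAX_SIZE
        retry_max_attempts=3,            # HISTORY_RETRY_MAX_ATTEMPTS
        retry_backoff_sec=0.5,           # HISTORY_RETRY_BACKOFF_SEC
        finalize_drain_timeout_sec=2.0,  # HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC
    )
    await bridge.start_call(user_id=1, assistant_id="asst_1", source="debug")
    bridge.enqueue_turn(role="user", text="hello")         # non-blocking put_nowait
    bridge.enqueue_turn(role="assistant", text="hi there")
    await bridge.finalize(status="connected")              # drains queue, finalizes once
    await bridge.shutdown()
```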

View File

@@ -2,6 +2,11 @@
This document defines the public WebSocket protocol for the `/ws` endpoint.
Validation policy:
- WS v1 JSON control messages are validated strictly.
- Unknown top-level fields are rejected for all defined client message types.
- `hello.version` is fixed to `"v1"`.
## Transport
- A single WebSocket connection carries:
@@ -52,43 +57,26 @@ Rules:
"channels": 1
},
"metadata": {
"appId": "assistant_123",
"channel": "web",
"configVersionId": "cfg_20260217_01",
"client": "web-debug",
"output": {
"mode": "audio"
},
"systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?",
"services": {
"llm": {
"provider": "openai",
"model": "gpt-4o-mini",
"apiKey": "sk-...",
"baseUrl": "https://api.openai.com/v1"
},
"asr": {
"provider": "openai_compatible",
"model": "FunAudioLLM/SenseVoiceSmall",
"apiKey": "sf-...",
"interimIntervalMs": 500,
"minAudioMs": 300
},
"tts": {
"enabled": true,
"provider": "openai_compatible",
"model": "FunAudioLLM/CosyVoice2-0.5B",
"apiKey": "sf-...",
"voice": "anna",
"speed": 1.0
}
}
"greeting": "Hi, how can I help?"
}
}
```
`metadata.services` is optional. If omitted, server defaults to environment configuration.
Rules:
- Client-side `metadata.services` is ignored.
- Service config (including secrets) is resolved server-side (env/backend).
- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints).
Text-only mode:
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`.
- Set `metadata.output.mode = "text"`.
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
### `input.text`
@@ -121,6 +109,7 @@ Text-only mode:
### `tool_call.results`
Client tool execution results returned to server.
Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side).
```json
{
@@ -138,21 +127,36 @@ Client tool execution results returned to server.
## Server -> Client Events
All server events include:
All server events include an envelope:
```json
{
"type": "event.name",
"timestamp": 1730000000000
"timestamp": 1730000000000,
"sessionId": "sess_xxx",
"seq": 42,
"source": "asr",
"trackId": "audio_in",
"data": {}
}
```
Envelope notes:
- `seq` is monotonically increasing within one session (for replay/resume).
- `source` is one of: `asr | llm | tts | tool | system | client | server`.
- For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side.
- `data` is the structured payload; legacy top-level fields are kept for compatibility.
Common events:
- `hello.ack`
- Fields: `sessionId`, `version`
- `session.started`
- Fields: `sessionId`, `trackId`, `audio`
- Fields: `sessionId`, `trackId`, `tracks`, `audio`
- `config.resolved`
- Fields: `sessionId`, `trackId`, `config`
- Sent immediately after `session.started`.
- Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets.
- `session.stopped`
- Fields: `sessionId`, `reason`
- `heartbeat`
@@ -169,9 +173,10 @@ Common events:
- `assistant.response.final`
- Fields: `trackId`, `text`
- `assistant.tool_call`
- Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`)
- Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms`
- `assistant.tool_result`
- Fields: `trackId`, `source`, `result`
- Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error`
- `error`: `{ code, message, retryable }` when `ok=false`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`
@@ -182,16 +187,54 @@ Common events:
- Fields: `trackId`, `latencyMs`
- `error`
- Fields: `sender`, `code`, `message`, `trackId`
- `trackId` convention:
- `audio_in` for `stage in {audio, asr}`
- `audio_out` for `stage in {llm, tts, tool}`
- `control` otherwise (including protocol/auth errors)
Track IDs (MVP fixed values):
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
Correlation IDs (`event.data`):
- `turn_id`: one user-assistant interaction turn.
- `utterance_id`: one ASR final utterance.
- `response_id`: one assistant response generation.
- `tool_call_id`: one tool invocation.
- `tts_id`: one TTS playback segment.
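For example, a `transcript.final` event carries its correlation IDs inside `data` (illustrative values; legacy top-level fields are kept alongside and may also be mirrored into `data`):
```json
{
  "type": "transcript.final",
  "timestamp": 1730000000000,
  "sessionId": "sess_xxx",
  "seq": 12,
  "source": "asr",
  "trackId": "audio_in",
  "text": "turn the lights off",
  "data": {
    "text": "turn the lights off",
    "turn_id": "turn_1_1a2b3c4d",
    "utterance_id": "utt_1_5e6f7a8b"
  }
}
```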
## Binary Audio Frames
After `session.started`, client may send binary PCM chunks continuously.
Recommended format:
- 16-bit signed little-endian PCM.
- 1 channel.
- 16000 Hz.
- 20ms frames (640 bytes) preferred.
MVP fixed format:
- 16-bit signed little-endian PCM (`pcm_s16le`)
- mono (1 channel)
- 16000 Hz
- 20ms frame = 640 bytes
Framing rules:
- Binary audio frame unit is 640 bytes.
- A WS binary message may carry one or multiple complete 640-byte frames.
- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly).
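A client-side framing sketch under these rules (`ws.send_bytes` is an assumption standing in for whichever WebSocket client is used):
```python
FRAME_BYTES = 640  # 20ms of 16kHz mono pcm_s16le

async def send_pcm(ws, buffered: bytes) -> bytes:
    """Send complete 640-byte frames; return the unsent tail so the caller
    can prepend it to the next capture chunk (partial frames are rejected)."""
    usable = len(buffered) - (len(buffered) % FRAME_BYTES)
    for i in range(0, usable, FRAME_BYTES):
        await ws.send_bytes(buffered[i:i + FRAME_BYTES])
    return buffered[usable:]
```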
TTS boundary events:
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
## Event Throttling
To keep client rendering and server load stable, v1 applies/recommends:
- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms).
- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms).
- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms.
## Error Structure
`error` keeps legacy top-level fields (`code`, `message`) and adds structured info:
- `stage`: `protocol | asr | llm | tts | tool | audio`
- `retryable`: boolean
- `data.error`: `{ stage, code, message, retryable }`
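An illustrative `error` event under this structure (values are examples):
```json
{
  "type": "error",
  "timestamp": 1730000000000,
  "sessionId": "sess_xxx",
  "seq": 7,
  "source": "system",
  "trackId": "audio_in",
  "sender": "client",
  "code": "audio.frame_size_mismatch",
  "message": "Audio frame size must be a multiple of 640 bytes (20ms PCM)",
  "stage": "audio",
  "retryable": false,
  "data": {
    "error": {
      "stage": "audio",
      "code": "audio.frame_size_mismatch",
      "message": "Audio frame size must be a multiple of 640 bytes (20ms PCM)",
      "retryable": false
    }
  }
}
```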
## Compatibility

View File

@@ -0,0 +1,520 @@
# WS v1 Protocol: Full Reference
This document describes the WebSocket v1 protocol for the `/ws` endpoint, covering:
- client input (JSON text messages + binary audio);
- server output (JSON events + binary audio);
- the type, constraints, meaning, and usage of each field;
- handshake order, state machine, error semantics, and implementation details.
Implementation cross-references:
- `models/ws_v1.py`
- `core/session.py`
- `core/duplex_pipeline.py`
- `app/main.py`
---
## 1. Transport and Basic Rules
- Endpoint: `ws://<host>/ws`
- A single connection carries two channels:
  - text frames: JSON control messages (strict schema validation)
  - binary frames: raw PCM audio
- JSON validation policy:
  - every defined client message uses `extra="forbid"`, i.e. undeclared fields are rejected;
  - `hello.version` must be exactly `"v1"`;
  - a missing or unknown `type` returns a protocol error.
---
## 2. State Machine and Message Order
### 2.1 Server States
- `WAIT_HELLO`: waiting for `hello`
- `WAIT_START`: handshake passed, waiting for `session.start`
- `ACTIVE`: session running; text/audio may be exchanged
- `STOPPED`: session ended
### 2.2 Correct Order
1. Client sends `hello`
2. Server replies `hello.ack`
3. Client sends `session.start`
4. Server replies `session.started`
5. Client may then continuously send:
   - binary audio
   - `input.text` (optional)
   - `response.cancel` (optional)
   - `tool_call.results` (optional)
6. Client sends `session.stop`, or simply disconnects
Out-of-order messages return `error` with `code = "protocol.order"`. A minimal handshake sketch following this order is shown below.
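A minimal client sketch (assumptions: the third-party `websockets` package, a local server at `ws://localhost:8000/ws`, and no auth configured):
```python
import asyncio
import json
import websockets  # assumed client library

async def main() -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:  # assumed URL
        await ws.send(json.dumps({"type": "hello", "version": "v1"}))
        assert json.loads(await ws.recv())["type"] == "hello.ack"
        await ws.send(json.dumps({
            "type": "session.start",
            "metadata": {"output": {"mode": "text"}},  # text-only for the demo
        }))
        assert json.loads(await ws.recv())["type"] == "session.started"
        await ws.send(json.dumps({"type": "input.text", "text": "hello"}))
        for _ in range(5):  # e.g. config.resolved, response deltas, final
            print(json.loads(await ws.recv()))
        await ws.send(json.dumps({"type": "session.stop", "reason": "demo"}))

asyncio.run(main())
```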
---
## 3. Client -> Server Messages (Input)
### 3.1 `hello`
Example:
```json
{
"type": "hello",
"version": "v1",
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
}
```
Fields:
| Field | Type | Required | Constraints | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | fixed `"hello"` | message type | first message of the handshake |
| `version` | string | yes | fixed `"v1"` | protocol version | a mismatch triggers `protocol.version_unsupported` and disconnects |
| `auth` | object \| null | no | only `apiKey` and `jwt` allowed | auth payload | the auth policy is decided by server configuration |
| `auth.apiKey` | string \| null | no | any string | API key | if the server configures `WS_API_KEY`, it must match exactly |
| `auth.jwt` | string \| null | no | any string | JWT string | with `WS_REQUIRE_AUTH=true`, can satisfy the "credentials present" requirement |
Auth behavior:
- If `WS_API_KEY` is set: `auth.apiKey` must be provided and match, otherwise `auth.invalid_api_key` and the connection is closed.
- If `WS_REQUIRE_AUTH=true` and `WS_API_KEY` is not set: at least one of `auth.apiKey` or `auth.jwt` must be non-empty, otherwise `auth.required` and the connection is closed.
## 3.2 `session.start`
Example:
```json
{
  "type": "session.start",
  "audio": {
    "encoding": "pcm_s16le",
    "sample_rate_hz": 16000,
    "channels": 1
  },
  "metadata": {
    "appId": "assistant_123",
    "channel": "web",
    "configVersionId": "cfg_20260217_01",
    "client": "web-debug",
    "output": {
      "mode": "audio"
    },
    "systemPrompt": "You are a concise assistant",
    "greeting": "Hello, how can I help you?"
  }
}
```
Fields:
| Field | Type | Required | Constraint | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | fixed `"session.start"` | start the session | second handshake phase |
| `audio` | object \| null | no | fixed values only | audio format declaration | declarative only; the MVP accepts a single fixed PCM format |
| `audio.encoding` | string | no | fixed `"pcm_s16le"` | encoding | other values fail model validation |
| `audio.sample_rate_hz` | number | no | fixed `16000` | sample rate | 16 kHz |
| `audio.channels` | number | no | fixed `1` | channel count | mono |
| `metadata` | object \| null | no | arbitrary object (allowlist-filtered) | runtime config | overrides for app/channel/prompt/output mode, etc. |
`metadata` allowlist policy (important):
- Identifier fields passed through (ID-like):
  - `appId` / `app_id`
  - `channel`
  - `configVersionId` / `config_version_id`
- Override fields passed through:
  - `firstTurnMode`
  - `greeting`
  - `generatedOpenerEnabled`
  - `systemPrompt`
  - `output`
  - `bargeIn`
  - `knowledge`
  - `knowledgeBaseId`
  - `history`
  - `userId`
  - `assistantId`
  - `source`
- A client-supplied `metadata.services` is ignored (the server logs a warning); service configuration is decided by the backend/environment variables.
`output.mode` usage:
- `"audio"` (default, spoken output)
- `"text"` (text-only output)
- in text-only mode the client still receives `assistant.response.delta/final`;
- it receives no TTS audio frames and no `output.audio.start/end`.
## 3.3 `input.text`
Example:
```json
{
  "type": "input.text",
  "text": "What can you do?"
}
```
Fields:
| Field | Type | Required | Constraint | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | fixed `"input.text"` | text input | skips ASR and triggers an LLM answer directly |
| `text` | string | yes | non-empty string recommended | user text | for text chat or debugging |
## 3.4 `response.cancel`
Example:
```json
{
  "type": "response.cancel",
  "graceful": false
}
```
Fields:
| Field | Type | Required | Default | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | - | fixed `"response.cancel"` | request to interrupt the current answer |
| `graceful` | boolean | no | `false` | cancel mode | `false` interrupts immediately; `true` is currently only logged and does not force an interrupt |
## 3.5 `tool_call.results`
Only used when the tool executor is the client (`assistant.tool_call.executor == "client"`).
Example:
```json
{
  "type": "tool_call.results",
  "results": [
    {
      "tool_call_id": "call_abc123",
      "name": "weather",
      "output": { "temp_c": 21, "condition": "sunny" },
      "status": { "code": 200, "message": "ok" }
    }
  ]
}
```
Fields:
| Field | Type | Required | Constraint | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | fixed `"tool_call.results"` | tool execution report | client uploads tool results |
| `results` | array | no | defaults to empty array | multiple tool results | may be batched |
| `results[].tool_call_id` | string | yes | any string | tool call ID | must match `assistant.tool_call.tool_call_id` |
| `results[].name` | string | yes | any string | tool name | should match the request |
| `results[].output` | any | no | any JSON | tool output | used by the model to compose the follow-up answer |
| `results[].status` | object | yes | contains `code` and `message` | execution status | used to judge success/failure |
| `results[].status.code` | number | yes | HTTP-style status code | status code | `200-299` counts as success |
| `results[].status.message` | string | yes | any string | status description | e.g. `"ok"` / `"timeout"` |
Processing rules (see the reply sketch below):
- a `tool_call_id` that was never requested is ignored (prevents forgery/cross-talk);
- duplicate reports are ignored;
- on timeout the server synthesizes a timeout result (`504`).
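A minimal client-side reply sketch; `ws` is assumed to expose an async `send(str)`, as in the `websockets` package:
```python
import json


async def reply_tool_call(ws, tool_call: dict, output, ok: bool = True) -> None:
    """Answer an `assistant.tool_call` event with `tool_call.results`."""
    await ws.send(json.dumps({
        "type": "tool_call.results",
        "results": [{
            "tool_call_id": tool_call["tool_call_id"],  # must echo the request
            "name": tool_call["tool_name"],
            "output": output,
            # 200-299 counts as success on the server side
            "status": {"code": 200 if ok else 500,
                       "message": "ok" if ok else "tool_failed"},
        }],
    }))
```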
## 3.6 `session.stop`
Example:
```json
{
  "type": "session.stop",
  "reason": "client_disconnect"
}
```
Fields:
| Field | Type | Required | Constraint | Meaning | Usage |
|---|---|---|---|---|---|
| `type` | string | yes | fixed `"session.stop"` | end the session | recommended for a clean shutdown |
| `reason` | string \| null | no | any string | stop reason | echoed back in `session.stopped.reason` |
---
## 4. Binary Audio Input (Client -> Server)
Binary audio may be streamed continuously after `session.started`.
Fixed format (MVP):
- encoding: `pcm_s16le`
- sample rate: `16000`
- channels: `1`
- frame length: 20ms = `640 bytes`
Packetization rules:
- a single WebSocket binary message may carry one or more frames;
- its length must be a multiple of `640`;
- a length that is not a multiple of `640` triggers `audio.frame_size_mismatch` and the whole message is dropped;
- an odd byte length triggers `audio.invalid_pcm`.
---
## 5. Server -> Client Events (Output)
All JSON events share a common envelope.
## 5.1 Envelope
```json
{
  "type": "event.name",
  "timestamp": 1730000000000,
  "sessionId": "sess_xxx",
  "seq": 42,
  "source": "asr",
  "trackId": "audio_in",
  "data": {}
}
```
Fields:
| Field | Type | Meaning | Usage |
|---|---|---|---|
| `type` | string | event type | see the event list below |
| `timestamp` | number | event timestamp (ms) | generated by `ev()` |
| `sessionId` | string | session ID | fixed per connection |
| `seq` | number | per-session increasing sequence number | usable for replay, dedup, ordering (see the sketch below) |
| `source` | string | event origin | common: `asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` |
| `trackId` | string | event track | common: `audio_in`/`audio_out`/`control` |
| `data` | object | structured payload | top-level business fields are mirrored into `data` for legacy clients |
Correlation IDs (auto-injected into `data` when present):
- `turn_id`: one user-assistant turn
- `utterance_id`: one user utterance
- `response_id`: one assistant response generation
- `tool_call_id`: one tool invocation
- `tts_id`: one TTS playback segment
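A client-side sketch that uses `seq` for dedup and ordering, assuming `seq` increases monotonically per session as specified above:
```python
class SeqGuard:
    """Drop duplicate or stale events based on the envelope `seq`."""

    def __init__(self) -> None:
        self.last_seq = -1

    def accept(self, event: dict) -> bool:
        seq = event.get("seq")
        if not isinstance(seq, int):
            return True  # tolerate events without a usable seq
        if seq <= self.last_seq:
            return False  # duplicate or out-of-order: skip
        self.last_seq = seq
        return True
```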
## 5.2 Event Types and Parameters
### 5.2.1 Session and control
1. `hello.ack`
   - Key fields: `version`
   - Meaning: handshake succeeded; send `session.start` next
2. `session.started`
   - Key fields:
     - `trackId`
     - `tracks.audio_in`
     - `tracks.audio_out`
     - `tracks.control`
     - `audio` (echo of the audio metadata declared by the client)
   - Meaning: the session is ACTIVE; audio/text may now be sent
3. `config.resolved`
   - Key fields:
     - `config.appId`
     - `config.channel`
     - `config.configVersionId`
     - `config.prompt.sha256`
     - `config.output`
     - `config.services` (effective service config with secrets stripped)
     - `config.tools.allowlist`
     - `config.tracks`
   - Meaning: snapshot of the final effective server config, for frontend display and troubleshooting
4. `heartbeat`
   - Key fields: none (envelope only)
   - Meaning: keepalive heartbeat
   - Default interval: `heartbeat_interval_sec` (default 50s)
5. `session.stopped`
   - Key fields: `reason`
   - Meaning: session end confirmation
6. `error`
   - Key fields:
     - `sender`
     - `code`
     - `message`
     - `stage`
     - `retryable`
     - `trackId`
     - `data.error` (structured error mirror)
   - Meaning: unified error event
### 5.2.2 Recognition and input side (ASR/VAD)
1. `input.speech_started`
   - Fields: `probability`
   - Meaning: speech onset detected
2. `input.speech_stopped`
   - Fields: `probability`
   - Meaning: speech end detected
3. `transcript.delta`
   - Fields: `text`
   - Meaning: incremental ASR text (throttled)
4. `transcript.final`
   - Fields: `text`
   - Meaning: final ASR text
### 5.2.3 Output side (LLM/TTS/Tool)
1. `assistant.response.delta`
   - Fields: `text`
   - Meaning: incremental assistant text (throttled)
2. `assistant.response.final`
   - Fields: `text`
   - Meaning: complete assistant text
3. `assistant.tool_call`
   - Fields:
     - `tool_call_id`
     - `tool_name`
     - `arguments` (object)
     - `executor` (`client` or `server`)
     - `timeout_ms`
     - `tool_call` (full tool-call object)
   - Meaning: notifies the client of a tool call (for visualization or client-side execution)
4. `assistant.tool_result`
   - Fields:
     - `source` (`client` or `server`)
     - `tool_call_id`
     - `tool_name`
     - `ok` (boolean)
     - `error` (`{code,message,retryable}` on failure)
     - `result` (raw result object)
   - Meaning: receipt of a tool-call result
5. `output.audio.start`
   - Meaning: TTS audio output start boundary
6. `output.audio.end`
   - Meaning: TTS audio output end boundary
7. `response.interrupted`
   - Meaning: the current answer was interrupted (barge-in or cancel)
8. `metrics.ttfb`
   - Fields: `latencyMs`
   - Meaning: first-audio-packet latency (TTFB)
### 5.2.4 Workflow extension events (optional)
When `metadata.workflow` is in effect, these additional events appear:
- `workflow.started`
- `workflow.node.entered`
- `workflow.edge.taken`
- `workflow.tool.requested`
- `workflow.human_transfer`
- `workflow.ended`
They exist to visualize workflow state externally and do not affect the base voice session protocol.
---
## 6. Server Binary Audio Output (Server -> Client)
- Audio is sent as binary PCM frames;
- send units are aligned to `640 bytes` (short tails are zero-padded before sending);
- frontends usually use `output.audio.start/end` to bound playback;
- after `response.interrupted`, discard any queued audio that has not played yet (see the gate sketch below).
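A minimal playback-gate sketch implementing these rules; actual buffering and playback are up to the client, this only shows the accept/discard state machine:
```python
class PlaybackGate:
    """Gate server PCM frames on output.audio.* and response.interrupted."""

    def __init__(self) -> None:
        self.accepting = False
        self.queue = bytearray()

    def on_event(self, event_type: str) -> None:
        if event_type == "output.audio.start":
            self.accepting = True
        elif event_type == "output.audio.end":
            self.accepting = False
        elif event_type == "response.interrupted":
            self.accepting = False
            self.queue.clear()  # drop audio that has not been played yet

    def on_audio(self, frame: bytes) -> None:
        if self.accepting:
            self.queue.extend(frame)
```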
---
## 7. Error Model and Common Error Codes
Unified structure (`error` event):
```json
{
  "type": "error",
  "sender": "client",
  "code": "protocol.invalid_message",
  "message": "Invalid message: ...",
  "stage": "protocol",
  "retryable": false,
  "trackId": "control",
  "data": {
    "error": {
      "stage": "protocol",
      "code": "protocol.invalid_message",
      "message": "Invalid message: ...",
      "retryable": false
    }
  }
}
```
Field semantics:
- `sender`: error origin role (e.g. `client` / `server` / `auth`)
- `code`: machine-readable error code
- `message`: human-readable description
- `stage`: pipeline stage (`protocol|audio|asr|llm|tts|tool`)
- `retryable`: whether a retry is advisable
- `trackId`: track the error belongs to
Common error codes:
- `protocol.invalid_json`
- `protocol.invalid_message`
- `protocol.order`
- `protocol.version_unsupported`
- `protocol.unsupported`
- `auth.invalid_api_key`
- `auth.required`
- `audio.invalid_pcm`
- `audio.frame_size_mismatch`
- `audio.processing_failed`
- `server.internal`
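A dispatch sketch that branches on `code` and `stage`, as recommended in section 9 below; the branch bodies are placeholders:
```python
def handle_error(event: dict) -> None:
    """Route a v1 `error` event by machine-readable fields, not message text."""
    err = (event.get("data") or {}).get("error") or event
    stage = err.get("stage", "protocol")
    code = err.get("code", "server.internal")
    if code.startswith("auth."):
        raise RuntimeError(f"fatal auth error: {code}")  # reconnect with creds
    if stage == "audio":
        print(f"audio framing problem, message was dropped: {code}")
    elif err.get("retryable"):
        print(f"transient {stage} error, retry is reasonable: {code}")
    else:
        print(f"non-retryable {stage} error: {code}")
```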
---
## 8. Heartbeat and Timeouts
Server background-task logic:
- checks the connection roughly every 5 seconds;
- closes the session if no client message arrives within `inactivity_timeout_sec` (default 60 seconds);
- sends a `heartbeat` every `heartbeat_interval_sec` (default 50 seconds).
Client recommendations:
- keep uplink audio flowing or periodically send lightweight text messages to avoid being judged idle;
- use `heartbeat` plus `seq` to detect connection liveness and event reordering.
---
## 9. Practical Integration Tips
1. Send `hello` immediately after connecting; send `session.start` only after `hello.ack` arrives.
2. Keep voice input strictly 16k/16-bit/mono, and make every WS binary message `640*n` bytes long.
3. In the UI, render `assistant.response.delta` as streaming text and treat `assistant.response.final` as the converged result.
4. Drive the player's per-response lifecycle with `output.audio.start/end`.
5. For tool calls with `executor=client`, always report back via `tool_call.results` keyed by `tool_call_id`.
6. On `error`, branch on `code` first rather than on `message` alone.
---
## 10. Minimal End-to-End Sequence
```text
Client -> hello
Server <- hello.ack
Client -> session.start
Server <- session.started
Server <- config.resolved
Client -> (binary pcm frames...)
Server <- input.speech_started / transcript.delta / transcript.final
Server <- assistant.response.delta / assistant.response.final
Server <- output.audio.start
Server <- (binary pcm frames...)
Server <- output.audio.end
Client -> session.stop
Server <- session.stopped
```
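A runnable text-mode client matching this sequence, assuming a local server and the third-party `websockets` package (`pip install websockets`):
```python
import asyncio
import json

import websockets  # third-party dependency, not part of this protocol


async def text_roundtrip(url: str = "ws://localhost:8000/ws") -> None:
    """hello -> session.start (text mode) -> input.text -> session.stop."""
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"type": "hello", "version": "v1"}))
        async for msg in ws:
            if isinstance(msg, bytes):
                continue  # text mode: no audio frames expected
            event = json.loads(msg)
            etype = event.get("type")
            if etype == "hello.ack":
                await ws.send(json.dumps({
                    "type": "session.start",
                    "metadata": {"appId": "assistant_demo",
                                 "output": {"mode": "text"}},
                }))
            elif etype == "session.started":
                await ws.send(json.dumps({"type": "input.text", "text": "hello"}))
            elif etype == "assistant.response.final":
                print("assistant:", event.get("text"))
                await ws.send(json.dumps({"type": "session.stop"}))
            elif etype == "session.stopped":
                break


if __name__ == "__main__":
    asyncio.run(text_roundtrip())
```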

View File

@@ -59,8 +59,12 @@ class MicrophoneClient:
url: str,
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
app_id: str = "assistant_demo",
channel: str = "mic_client",
config_version_id: str = "local-dev",
input_device: int = None,
output_device: int = None
output_device: int = None,
track_debug: bool = False,
):
"""
Initialize microphone client.
@@ -76,8 +80,12 @@ class MicrophoneClient:
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.input_device = input_device
self.output_device = output_device
self.track_debug = track_debug
# WebSocket connection
self.ws = None
@@ -107,6 +115,17 @@ class MicrophoneClient:
# Verbose mode for streaming LLM responses
self.verbose = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None:
"""Connect to WebSocket server."""
print(f"Connecting to {self.url}...")
@@ -114,20 +133,30 @@ class MicrophoneClient:
self.running = True
print("Connected!")
# Send invite command
# WS v1 handshake: hello -> session.start
await self.send_command({
"command": "invite",
"option": {
"codec": "pcm",
"sampleRate": self.sample_rate
}
"type": "hello",
"version": "v1",
})
await self.send_command({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
})
async def send_command(self, cmd: dict) -> None:
"""Send JSON command to server."""
if self.ws:
await self.ws.send(json.dumps(cmd))
print(f"→ Command: {cmd.get('command', 'unknown')}")
print(f"→ Command: {cmd.get('type', 'unknown')}")
async def send_chat(self, text: str) -> None:
"""Send chat message (text input)."""
@@ -136,7 +165,7 @@ class MicrophoneClient:
self.first_audio_received = False
await self.send_command({
"command": "chat",
"type": "input.text",
"text": text
})
print(f"→ Chat: {text}")
@@ -144,13 +173,14 @@ class MicrophoneClient:
async def send_interrupt(self) -> None:
"""Send interrupt command."""
await self.send_command({
"command": "interrupt"
"type": "response.cancel",
"graceful": False,
})
async def send_hangup(self, reason: str = "User quit") -> None:
"""Send hangup command."""
await self.send_command({
"command": "hangup",
"type": "session.stop",
"reason": reason
})
@@ -295,43 +325,48 @@ class MicrophoneClient:
async def _handle_event(self, event: dict) -> None:
"""Handle incoming event."""
event_type = event.get("event", "unknown")
event_type = event.get("type", event.get("event", "unknown"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "answer":
print("← Session ready!")
elif event_type == "speaking":
print("User speech detected")
elif event_type == "silence":
print("← User silence detected")
elif event_type == "transcript":
if event_type in {"hello.ack", "session.started"}:
print(f"← Session ready!{ids}")
elif event_type == "config.resolved":
print(f"Config resolved: {event.get('config', {}).get('output', {})}{ids}")
elif event_type == "input.speech_started":
print(f"← User speech detected{ids}")
elif event_type == "input.speech_stopped":
print(f"← User silence detected{ids}")
elif event_type in {"transcript", "transcript.delta", "transcript.final"}:
# Display user speech transcription
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = event_type == "transcript.final" or bool(event.get("isFinal"))
if is_final:
# Clear the interim line and print final
print(" " * 80, end="\r") # Clear previous interim text
print(f"→ You: {text}")
print(f"→ You: {text}{ids}")
else:
# Interim result - show with indicator (overwrite same line)
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [listening] {display_text}".ljust(80), end="\r")
elif event_type == "ttfb":
elif event_type in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event
latency_ms = event.get("latencyMs", 0)
print(f"← [TTFB] Server reported latency: {latency_ms}ms")
elif event_type == "llmResponse":
elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}:
# LLM text response
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = event_type == "assistant.response.final" or bool(event.get("isFinal"))
if is_final:
# Print final LLM response
print(f"← AI: {text}")
elif self.verbose:
# Show streaming chunks only in verbose mode
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [streaming] {display_text}")
elif event_type == "trackStart":
print("← Bot started speaking")
print(f" [streaming] {display_text}{ids}")
elif event_type in {"trackStart", "output.audio.start"}:
print(f"← Bot started speaking{ids}")
# IMPORTANT: Accept audio again after trackStart
self._discard_audio = False
self._audio_sequence += 1
@@ -342,13 +377,13 @@ class MicrophoneClient:
# Clear any old audio in buffer
with self.audio_output_lock:
self.audio_output_buffer = b""
elif event_type == "trackEnd":
print("← Bot finished speaking")
elif event_type in {"trackEnd", "output.audio.end"}:
print(f"← Bot finished speaking{ids}")
# Reset TTFB tracking after response completes
self.request_start_time = None
self.first_audio_received = False
elif event_type == "interrupt":
print("← Bot interrupted!")
elif event_type in {"interrupt", "response.interrupted"}:
print(f"← Bot interrupted!{ids}")
# IMPORTANT: Discard all audio until next trackStart
self._discard_audio = True
# Clear audio buffer immediately
@@ -357,12 +392,12 @@ class MicrophoneClient:
self.audio_output_buffer = b""
print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
elif event_type == "error":
print(f"← Error: {event.get('error')}")
elif event_type == "hangup":
print(f"← Hangup: {event.get('reason')}")
print(f"← Error: {event.get('error')}{ids}")
elif event_type in {"hangup", "session.stopped"}:
print(f"← Hangup: {event.get('reason')}{ids}")
self.running = False
else:
print(f"← Event: {event_type}")
print(f"← Event: {event_type}{ids}")
async def interactive_mode(self) -> None:
"""Run interactive mode for text chat."""
@@ -573,6 +608,26 @@ async def main():
action="store_true",
help="Show streaming LLM response chunks"
)
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="mic_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
args = parser.parse_args()
@@ -583,8 +638,12 @@ async def main():
client = MicrophoneClient(
url=args.url,
sample_rate=args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
input_device=args.input_device,
output_device=args.output_device
output_device=args.output_device,
track_debug=args.track_debug,
)
client.verbose = args.verbose

View File

@@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
class SimpleVoiceClient:
"""Simple voice client with reliable audio playback."""
def __init__(self, url: str, sample_rate: int = 16000):
def __init__(
self,
url: str,
sample_rate: int = 16000,
app_id: str = "assistant_demo",
channel: str = "simple_client",
config_version_id: str = "local-dev",
track_debug: bool = False,
):
self.url = url
self.sample_rate = sample_rate
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.track_debug = track_debug
self.ws = None
self.running = False
@@ -76,6 +88,17 @@ class SimpleVoiceClient:
# Interrupt handling - discard audio until next trackStart
self._discard_audio = False
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self):
"""Connect to server."""
print(f"Connecting to {self.url}...")
@@ -83,12 +106,25 @@ class SimpleVoiceClient:
self.running = True
print("Connected!")
# Send invite
# WS v1 handshake: hello -> session.start
await self.ws.send(json.dumps({
"command": "invite",
"option": {"codec": "pcm", "sampleRate": self.sample_rate}
"type": "hello",
"version": "v1",
}))
print("-> invite")
await self.ws.send(json.dumps({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1,
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
}))
print("-> hello/session.start")
async def send_chat(self, text: str):
"""Send chat message."""
@@ -96,8 +132,8 @@ class SimpleVoiceClient:
self.request_start_time = time.time()
self.first_audio_received = False
await self.ws.send(json.dumps({"command": "chat", "text": text}))
print(f"-> chat: {text}")
await self.ws.send(json.dumps({"type": "input.text", "text": text}))
print(f"-> input.text: {text}")
def play_audio(self, audio_data: bytes):
"""Play audio data immediately."""
@@ -152,34 +188,39 @@ class SimpleVoiceClient:
else:
# JSON event
event = json.loads(msg)
etype = event.get("event", "?")
etype = event.get("type", event.get("event", "?"))
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}")
if etype == "transcript":
if etype in {"transcript", "transcript.delta", "transcript.final"}:
# User speech transcription
text = event.get("text", "")
is_final = event.get("isFinal", False)
is_final = etype == "transcript.final" or bool(event.get("isFinal"))
if is_final:
print(f"<- You said: {text}")
print(f"<- You said: {text}{ids}")
else:
print(f"<- [listening] {text}", end="\r")
elif etype == "ttfb":
elif etype in {"ttfb", "metrics.ttfb"}:
# Server-side TTFB event
latency_ms = event.get("latencyMs", 0)
print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
elif etype == "trackStart":
elif etype in {"trackStart", "output.audio.start"}:
# New track starting - accept audio again
self._discard_audio = False
print(f"<- {etype}")
elif etype == "interrupt":
print(f"<- {etype}{ids}")
elif etype in {"interrupt", "response.interrupted"}:
# Interrupt - discard audio until next trackStart
self._discard_audio = True
print(f"<- {etype} (discarding audio until new track)")
elif etype == "hangup":
print(f"<- {etype}")
print(f"<- {etype}{ids} (discarding audio until new track)")
elif etype in {"hangup", "session.stopped"}:
print(f"<- {etype}{ids}")
self.running = False
break
elif etype == "config.resolved":
print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}")
else:
print(f"<- {etype}")
print(f"<- {etype}{ids}")
except asyncio.TimeoutError:
continue
@@ -270,6 +311,10 @@ async def main():
parser.add_argument("--text", help="Send text and play response")
parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--sample-rate", type=int, default=16000)
parser.add_argument("--app-id", default="assistant_demo")
parser.add_argument("--channel", default="simple_client")
parser.add_argument("--config-version-id", default="local-dev")
parser.add_argument("--track-debug", action="store_true")
args = parser.parse_args()
@@ -277,7 +322,14 @@ async def main():
list_audio_devices()
return
client = SimpleVoiceClient(args.url, args.sample_rate)
client = SimpleVoiceClient(
args.url,
args.sample_rate,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
track_debug=args.track_debug,
)
await client.run(args.text)

View File

@@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000):
return audio_data
async def receive_loop(ws, ready_event: asyncio.Event):
async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False):
"""Listen for incoming messages from the server."""
def event_ids_suffix(data):
payload = data.get("data") if isinstance(data.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = payload.get(key, data.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
print("👂 Listening for server responses...")
async for msg in ws:
timestamp = datetime.now().strftime("%H:%M:%S")
@@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event):
try:
data = json.loads(msg.data)
event_type = data.get('type', 'Unknown')
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
ids = event_ids_suffix(data)
print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...")
if track_debug:
print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}")
if event_type == "session.started":
ready_event.set()
except json.JSONDecodeError:
@@ -113,7 +126,7 @@ async def send_sine_loop(ws):
print("\n✅ Finished streaming test audio.")
async def run_client(url, file_path=None, use_sine=False):
async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False):
"""Run the WebSocket test client."""
session = aiohttp.ClientSession()
try:
@@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False):
async with session.ws_connect(url) as ws:
print("✅ Connected!")
session_ready = asyncio.Event()
recv_task = asyncio.create_task(receive_loop(ws, session_ready))
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
# Send v1 hello + session.start handshake
await ws.send_json({"type": "hello", "version": "v1"})
@@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False):
"encoding": "pcm_s16le",
"sample_rate_hz": SAMPLE_RATE,
"channels": 1
}
},
"metadata": {
"appId": "assistant_demo",
"channel": "test_websocket",
"configVersionId": "local-dev",
},
})
print("📤 Sent v1 hello/session.start")
await asyncio.wait_for(session_ready.wait(), timeout=8)
@@ -168,9 +186,10 @@ if __name__ == "__main__":
parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
parser.add_argument("--file", help="Path to PCM/WAV file to stream")
parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging")
args = parser.parse_args()
try:
asyncio.run(run_client(args.url, args.file, args.sine))
asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug))
except KeyboardInterrupt:
print("\n👋 Client stopped.")

View File

@@ -57,10 +57,15 @@ class WavFileClient:
url: str,
input_file: str,
output_file: str,
app_id: str = "assistant_demo",
channel: str = "wav_client",
config_version_id: str = "local-dev",
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
wait_time: float = 15.0,
verbose: bool = False
verbose: bool = False,
track_debug: bool = False,
tail_silence_ms: int = 800,
):
"""
Initialize WAV file client.
@@ -77,11 +82,17 @@ class WavFileClient:
self.url = url
self.input_file = Path(input_file)
self.output_file = Path(output_file)
self.app_id = app_id
self.channel = channel
self.config_version_id = config_version_id
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.wait_time = wait_time
self.verbose = verbose
self.track_debug = track_debug
self.tail_silence_ms = max(0, int(tail_silence_ms))
self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms
# WebSocket connection
self.ws = None
@@ -105,6 +116,7 @@ class WavFileClient:
self.track_started = False
self.track_ended = False
self.send_completed = False
self.session_ready = False
# Events log
self.events_log = []
@@ -125,6 +137,17 @@ class WavFileClient:
safe_message = message.encode('ascii', errors='replace').decode('ascii')
print(f"{direction} {safe_message}")
@staticmethod
def _event_ids_suffix(event: dict) -> str:
data = event.get("data") if isinstance(event.get("data"), dict) else {}
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
parts = []
for key in keys:
value = data.get(key, event.get(key))
if value:
parts.append(f"{key}={value}")
return f" [{' '.join(parts)}]" if parts else ""
async def connect(self) -> None:
"""Connect to WebSocket server."""
self.log_event("", f"Connecting to {self.url}...")
@@ -132,25 +155,35 @@ class WavFileClient:
self.running = True
self.log_event("", "Connected!")
# Send invite command
# WS v1 handshake: hello -> session.start
await self.send_command({
"command": "invite",
"option": {
"codec": "pcm",
"sampleRate": self.sample_rate
}
"type": "hello",
"version": "v1",
})
await self.send_command({
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": self.sample_rate,
"channels": 1
},
"metadata": {
"appId": self.app_id,
"channel": self.channel,
"configVersionId": self.config_version_id,
},
})
async def send_command(self, cmd: dict) -> None:
"""Send JSON command to server."""
if self.ws:
await self.ws.send(json.dumps(cmd))
self.log_event("", f"Command: {cmd.get('command', 'unknown')}")
self.log_event("", f"Command: {cmd.get('type', 'unknown')}")
async def send_hangup(self, reason: str = "Session complete") -> None:
"""Send hangup command."""
await self.send_command({
"command": "hangup",
"type": "session.stop",
"reason": reason
})
@@ -210,6 +243,10 @@ class WavFileClient:
end_sample = min(sent_samples + chunk_size, total_samples)
chunk = audio_data[sent_samples:end_sample]
chunk_bytes = chunk.tobytes()
if len(chunk_bytes) % self.frame_bytes != 0:
# v1 audio framing requires 640-byte (20ms) PCM units.
pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes)
chunk_bytes += b"\x00" * pad
# Send to server
if self.ws:
@@ -227,6 +264,16 @@ class WavFileClient:
# Server expects audio at real-time pace for VAD/ASR to work properly
await asyncio.sleep(self.chunk_duration_ms / 1000)
# Add a short silence tail to help VAD/EOU close the final utterance.
if self.tail_silence_ms > 0 and self.ws:
tail_frames = max(1, self.tail_silence_ms // 20)
silence = b"\x00" * self.frame_bytes
for _ in range(tail_frames):
await self.ws.send(silence)
self.bytes_sent += len(silence)
await asyncio.sleep(0.02)
self.log_event("", f"Sent trailing silence: {self.tail_silence_ms}ms")
self.send_completed = True
elapsed = time.time() - self.send_start_time
self.log_event("", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)")
@@ -277,54 +324,59 @@ class WavFileClient:
async def _handle_event(self, event: dict) -> None:
"""Handle incoming event."""
event_type = event.get("event", "unknown")
event_type = event.get("type", "unknown")
ids = self._event_ids_suffix(event)
if self.track_debug:
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
if event_type == "answer":
self.log_event("", "Session ready!")
elif event_type == "speaking":
self.log_event("", "Speech detected")
elif event_type == "silence":
self.log_event("", "Silence detected")
elif event_type == "transcript":
# ASR transcript (interim = asrDelta-style, final = asrFinal-style)
if event_type == "hello.ack":
self.log_event("", f"Handshake acknowledged{ids}")
elif event_type == "session.started":
self.session_ready = True
self.log_event("", f"Session ready!{ids}")
elif event_type == "config.resolved":
config = event.get("config", {})
self.log_event("", f"Config resolved (output={config.get('output', {})}){ids}")
elif event_type == "input.speech_started":
self.log_event("", f"Speech detected{ids}")
elif event_type == "input.speech_stopped":
self.log_event("", f"Silence detected{ids}")
elif event_type == "transcript.delta":
text = event.get("text", "")
is_final = event.get("isFinal", False)
if is_final:
# Clear interim line and print final
print(" " * 80, end="\r")
self.log_event("", f"→ You: {text}")
else:
# Interim result - show with indicator (overwrite same line, as in mic_client)
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [listening] {display_text}".ljust(80), end="\r")
elif event_type == "ttfb":
elif event_type == "transcript.final":
text = event.get("text", "")
print(" " * 80, end="\r")
self.log_event("", f"→ You: {text}{ids}")
elif event_type == "metrics.ttfb":
latency_ms = event.get("latencyMs", 0)
self.log_event("", f"[TTFB] Server latency: {latency_ms}ms")
elif event_type == "llmResponse":
elif event_type == "assistant.response.delta":
text = event.get("text", "")
is_final = event.get("isFinal", False)
if is_final:
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
elif self.verbose:
# Show streaming chunks only in verbose mode
self.log_event("", f"LLM: {text}")
elif event_type == "trackStart":
if self.verbose and text:
self.log_event("", f"LLM: {text}{ids}")
elif event_type == "assistant.response.final":
text = event.get("text", "")
if text:
self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}")
elif event_type == "output.audio.start":
self.track_started = True
self.response_start_time = time.time()
self.waiting_for_first_audio = True
self.log_event("", "Bot started speaking")
elif event_type == "trackEnd":
self.log_event("", f"Bot started speaking{ids}")
elif event_type == "output.audio.end":
self.track_ended = True
self.log_event("", "Bot finished speaking")
elif event_type == "interrupt":
self.log_event("", "Bot interrupted!")
self.log_event("", f"Bot finished speaking{ids}")
elif event_type == "response.interrupted":
self.log_event("", f"Bot interrupted!{ids}")
elif event_type == "error":
self.log_event("!", f"Error: {event.get('error')}")
elif event_type == "hangup":
self.log_event("", f"Hangup: {event.get('reason')}")
self.log_event("!", f"Error: {event.get('message')}{ids}")
elif event_type == "session.stopped":
self.log_event("", f"Session stopped: {event.get('reason')}{ids}")
self.running = False
else:
self.log_event("", f"Event: {event_type}")
self.log_event("", f"Event: {event_type}{ids}")
def save_output_wav(self) -> None:
"""Save received audio to output WAV file."""
@@ -359,12 +411,16 @@ class WavFileClient:
# Connect to server
await self.connect()
# Wait for answer
await asyncio.sleep(0.5)
# Start receiver task
receiver_task = asyncio.create_task(self.receiver())
# Wait for session.started before streaming audio
ready_start = time.time()
while self.running and not self.session_ready:
if time.time() - ready_start > 8.0:
raise TimeoutError("Timeout waiting for session.started")
await asyncio.sleep(0.05)
# Send audio
await self.audio_sender(audio_data)
@@ -464,6 +520,21 @@ async def main():
default=16000,
help="Target sample rate for audio (default: 16000)"
)
parser.add_argument(
"--app-id",
default="assistant_demo",
help="Stable app/assistant identifier for server-side config lookup"
)
parser.add_argument(
"--channel",
default="wav_client",
help="Client channel name"
)
parser.add_argument(
"--config-version-id",
default="local-dev",
help="Optional config version identifier"
)
parser.add_argument(
"--chunk-duration",
type=int,
@@ -481,6 +552,17 @@ async def main():
action="store_true",
help="Enable verbose output"
)
parser.add_argument(
"--track-debug",
action="store_true",
help="Print event trackId for protocol debugging"
)
parser.add_argument(
"--tail-silence-ms",
type=int,
default=800,
help="Trailing silence to send after WAV playback for EOU detection (default: 800)"
)
args = parser.parse_args()
@@ -488,10 +570,15 @@ async def main():
url=args.url,
input_file=args.input,
output_file=args.output,
app_id=args.app_id,
channel=args.channel,
config_version_id=args.config_version_id,
sample_rate=args.sample_rate,
chunk_duration_ms=args.chunk_duration,
wait_time=args.wait_time,
verbose=args.verbose
verbose=args.verbose,
track_debug=args.track_debug,
tail_silence_ms=args.tail_silence_ms,
)
await client.run()

View File

@@ -401,6 +401,9 @@
const targetSampleRate = 16000;
const playbackStopRampSec = 0.008;
const appId = "assistant_demo";
const channel = "web_client";
const configVersionId = "local-dev";
function logLine(type, text, data) {
const time = new Date().toLocaleTimeString();
@@ -604,15 +607,35 @@
logLine("sys", `${cmd.type}`, cmd);
}
function eventIdsSuffix(event) {
const data = event && typeof event.data === "object" && event.data ? event.data : {};
const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"];
const parts = [];
for (const key of keys) {
const value = data[key] || event[key];
if (value) parts.push(`${key}=${value}`);
}
return parts.length ? ` [${parts.join(" ")}]` : "";
}
function handleEvent(event) {
const type = event.type || "unknown";
logLine("event", type, event);
const ids = eventIdsSuffix(event);
logLine("event", `${type}${ids}`, event);
if (type === "hello.ack") {
sendCommand({
type: "session.start",
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
metadata: {
appId,
channel,
configVersionId,
},
});
}
if (type === "config.resolved") {
logLine("sys", "config.resolved", event.config || {});
}
if (type === "transcript.final") {
if (event.text) {
setInterim("You", "");

View File

@@ -1,7 +1,8 @@
"""WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationError
def now_ms() -> int:
@@ -11,37 +12,66 @@ def now_ms() -> int:
return int(time.time() * 1000)
class _StrictModel(BaseModel):
"""Protocol models reject unknown fields to enforce WS v1 schema."""
model_config = ConfigDict(extra="forbid")
# Client -> Server messages
class HelloMessage(BaseModel):
class HelloAuth(_StrictModel):
apiKey: Optional[str] = None
jwt: Optional[str] = None
class HelloMessage(_StrictModel):
type: Literal["hello"]
version: str = Field(..., description="Protocol version, currently v1")
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
class SessionStartMessage(BaseModel):
class SessionStartAudio(_StrictModel):
encoding: Literal["pcm_s16le"] = "pcm_s16le"
sample_rate_hz: Literal[16000] = 16000
channels: Literal[1] = 1
class SessionStartMessage(_StrictModel):
type: Literal["session.start"]
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
class SessionStopMessage(_StrictModel):
type: Literal["session.stop"]
reason: Optional[str] = None
class InputTextMessage(BaseModel):
class InputTextMessage(_StrictModel):
type: Literal["input.text"]
text: str
class ResponseCancelMessage(BaseModel):
class ResponseCancelMessage(_StrictModel):
type: Literal["response.cancel"]
graceful: bool = False
class ToolCallResultsMessage(BaseModel):
class ToolCallResultStatus(_StrictModel):
code: int
message: str
class ToolCallResult(_StrictModel):
tool_call_id: str
name: str
output: Any = None
status: ToolCallResultStatus
class ToolCallResultsMessage(_StrictModel):
type: Literal["tool_call.results"]
results: list[Dict[str, Any]] = Field(default_factory=list)
results: list[ToolCallResult] = Field(default_factory=list)
CLIENT_MESSAGE_TYPES = {
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
if not msg_class:
raise ValueError(f"Unknown client message type: {msg_type}")
try:
return msg_class(**data)
except ValidationError as exc:
details = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = err.get("msg", "invalid field")
details.append(f"{loc}: {msg}" if loc else msg)
raise ValueError("; ".join(details)) from exc
# Server -> Client event helpers

View File

@@ -17,6 +17,7 @@ pydantic>=2.5.3
pydantic-settings>=2.1.0
python-dotenv>=1.0.0
toml>=0.10.2
pyyaml>=6.0.1
# Logging
loguru>=0.7.2

View File

@@ -7,10 +7,10 @@ for real-time voice conversation.
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable
from loguru import logger
from app.backend_client import search_knowledge_context
from app.backend_adapters import build_backend_adapter_from_settings
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
@@ -37,20 +37,21 @@ class OpenAILLMService(BaseLLMService):
base_url: Optional[str] = None,
system_prompt: Optional[str] = None,
knowledge_config: Optional[Dict[str, Any]] = None,
knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None,
):
"""
Initialize OpenAI LLM service.
Args:
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars)
base_url: Custom API base URL (for Azure or compatible APIs)
system_prompt: Default system prompt for conversations
"""
super().__init__(model=model)
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = base_url or os.getenv("OPENAI_API_URL")
self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
self.system_prompt = system_prompt or (
"You are a helpful, friendly voice assistant. "
"Keep your responses concise and conversational. "
@@ -60,6 +61,11 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
if knowledge_searcher is None:
adapter = build_backend_adapter_from_settings()
self._knowledge_searcher = adapter.search_knowledge_context
else:
self._knowledge_searcher = knowledge_searcher
self._tool_schemas: List[Dict[str, Any]] = []
_RAG_DEFAULT_RESULTS = 5
@@ -224,7 +230,7 @@ class OpenAILLMService(BaseLLMService):
n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
results = await search_knowledge_context(
results = await self._knowledge_searcher(
kb_id=kb_id,
query=latest_user,
n_results=n_results,

View File

@@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti
import asyncio
import io
import os
import wave
from typing import AsyncIterator, Optional, Callable, Awaitable
from loguru import logger
@@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService):
def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
model: str = "FunAudioLLM/SenseVoiceSmall",
sample_rate: int = 16000,
language: str = "auto",
@@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService):
Args:
api_key: Provider API key
api_url: Provider API URL (defaults to SiliconFlow endpoint)
model: ASR model name or alias
sample_rate: Audio sample rate (16000 recommended)
language: Language code (auto for automatic detection)
@@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService):
if not AIOHTTP_AVAILABLE:
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
self.api_key = api_key
self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
self.model = self.MODELS.get(model.lower(), model)
self.interim_interval_ms = interim_interval_ms
self.min_audio_for_interim_ms = min_audio_for_interim_ms
@@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService):
async def connect(self) -> None:
"""Connect to the service."""
if not self.api_key:
raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.")
self._session = aiohttp.ClientSession(
headers={
"Authorization": f"Bearer {self.api_key}"
@@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService):
)
form_data.add_field('model', self.model)
async with self._session.post(self.API_URL, data=form_data) as response:
async with self._session.post(self.api_url, data=form_data) as response:
if response.status == 200:
result = await response.json()
text = result.get("text", "").strip()

View File

@@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
def __init__(
self,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
voice: str = "anna",
model: str = "FunAudioLLM/CosyVoice2-0.5B",
sample_rate: int = 16000,
@@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService):
Initialize OpenAI-compatible TTS service.
Args:
api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars)
api_url: Provider API URL (defaults to SiliconFlow endpoint)
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
model: Model name
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
@@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService):
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
self.model = model
self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
self._session: Optional[aiohttp.ClientSession] = None
self._cancel_event = asyncio.Event()
@@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
async def connect(self) -> None:
"""Initialize HTTP session."""
if not self.api_key:
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")
self._session = aiohttp.ClientSession(
headers={

View File

@@ -0,0 +1,252 @@
import os
from pathlib import Path
import pytest
os.environ.setdefault("LLM_API_KEY", "test-openai-key")
os.environ.setdefault("TTS_API_KEY", "test-tts-key")
os.environ.setdefault("ASR_API_KEY", "test-asr-key")
from app.config import load_settings
def _write_yaml(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8")
def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str:
return f"""
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.63
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
provider: openai_compatible
model: {llm_model}
temperature: 0.2
api_key: {llm_key}
api_url: https://example-llm.invalid/v1
tts:
provider: openai_compatible
api_key: test-tts-key
api_url: https://example-tts.invalid/v1/audio/speech
model: FunAudioLLM/CosyVoice2-0.5B
voice: anna
speed: 1.0
asr:
provider: openai_compatible
api_key: test-asr-key
api_url: https://example-asr.invalid/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: You are a strict test assistant.
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60
""".strip()
def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
config_dir = tmp_path / "config" / "agents"
_write_yaml(
config_dir / "support.yaml",
_full_agent_yaml(llm_model="gpt-4.1-mini"),
)
settings = load_settings(
argv=["--agent-profile", "support"],
)
assert settings.llm_model == "gpt-4.1-mini"
assert settings.llm_temperature == 0.2
assert settings.vad_threshold == 0.63
assert settings.agent_config_source == "cli_profile"
assert settings.agent_config_path == str((config_dir / "support.yaml").resolve())
def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
env_file = tmp_path / "config" / "agents" / "env.yaml"
cli_file = tmp_path / "config" / "agents" / "cli.yaml"
_write_yaml(env_file, _full_agent_yaml(llm_model="env-model"))
_write_yaml(cli_file, _full_agent_yaml(llm_model="cli-model"))
monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_file))
settings = load_settings(argv=["--agent-config", str(cli_file)])
assert settings.llm_model == "cli-model"
assert settings.agent_config_source == "cli_path"
assert settings.agent_config_path == str(cli_file.resolve())
def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
default_file = tmp_path / "config" / "agents" / "default.yaml"
_write_yaml(default_file, _full_agent_yaml(llm_model="from-default"))
monkeypatch.delenv("AGENT_CONFIG_PATH", raising=False)
monkeypatch.delenv("AGENT_PROFILE", raising=False)
settings = load_settings(argv=[])
assert settings.llm_model == "from-default"
assert settings.agent_config_source == "default"
assert settings.agent_config_path == str(default_file.resolve())
def test_missing_required_agent_settings_fail(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "missing-required.yaml"
_write_yaml(
file_path,
"""
agent:
llm:
model: gpt-4o-mini
""".strip(),
)
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
load_settings(argv=["--agent-config", str(file_path)])
def test_blank_required_provider_key_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "blank-key.yaml"
_write_yaml(file_path, _full_agent_yaml(llm_key=""))
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
load_settings(argv=["--agent-config", str(file_path)])
def test_missing_tts_api_url_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "missing-tts-url.yaml"
_write_yaml(
file_path,
_full_agent_yaml().replace(
" api_url: https://example-tts.invalid/v1/audio/speech\n",
"",
),
)
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
load_settings(argv=["--agent-config", str(file_path)])
def test_missing_asr_api_url_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "missing-asr-url.yaml"
_write_yaml(
file_path,
_full_agent_yaml().replace(
" api_url: https://example-asr.invalid/v1/audio/transcriptions\n",
"",
),
)
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
load_settings(argv=["--agent-config", str(file_path)])
def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "bad-agent.yaml"
_write_yaml(file_path, _full_agent_yaml() + "\n unknown_option: true")
with pytest.raises(ValueError, match="Unknown agent config keys"):
load_settings(argv=["--agent-config", str(file_path)])
def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "legacy-siliconflow.yaml"
_write_yaml(
file_path,
"""
agent:
siliconflow:
api_key: x
""".strip(),
)
with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"):
load_settings(argv=["--agent-config", str(file_path)])
def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "bad-ref.yaml"
_write_yaml(
file_path,
_full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"),
)
with pytest.raises(ValueError, match="Missing environment variable"):
load_settings(argv=["--agent-config", str(file_path)])
def test_agent_yaml_tools_list_is_loaded(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "tools-agent.yaml"
_write_yaml(
file_path,
_full_agent_yaml()
+ """
tools:
- current_time
- name: weather
description: Get weather by city.
parameters:
type: object
properties:
city:
type: string
required: [city]
executor: server
""",
)
settings = load_settings(argv=["--agent-config", str(file_path)])
assert isinstance(settings.tools, list)
assert settings.tools[0] == "current_time"
assert settings.tools[1]["name"] == "weather"
assert settings.tools[1]["executor"] == "server"
def test_agent_yaml_tools_must_be_list(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
file_path = tmp_path / "bad-tools-agent.yaml"
_write_yaml(
file_path,
_full_agent_yaml()
+ """
tools:
weather:
executor: server
""",
)
with pytest.raises(ValueError, match="Agent config key 'tools' must be a list"):
load_settings(argv=["--agent-config", str(file_path)])

View File

@@ -0,0 +1,150 @@
import aiohttp
import pytest
from app.backend_adapters import (
HistoryDisabledBackendAdapter,
HttpBackendAdapter,
NullBackendAdapter,
build_backend_adapter,
)
@pytest.mark.asyncio
async def test_build_backend_adapter_without_url_returns_null_adapter():
adapter = build_backend_adapter(
backend_url=None,
backend_mode="auto",
history_enabled=True,
timeout_sec=3,
)
assert isinstance(adapter, NullBackendAdapter)
assert await adapter.fetch_assistant_config("assistant_1") is None
assert (
await adapter.create_call_record(
user_id=1,
assistant_id="assistant_1",
source="debug",
)
is None
)
assert (
await adapter.add_transcript(
call_id="call_1",
turn_index=0,
speaker="human",
content="hi",
start_ms=0,
end_ms=100,
confidence=0.9,
duration_ms=100,
)
is False
)
assert (
await adapter.finalize_call_record(
call_id="call_1",
status="connected",
duration_seconds=2,
)
is False
)
assert await adapter.search_knowledge_context(kb_id="kb_1", query="hello", n_results=3) == []
assert await adapter.fetch_tool_resource("tool_1") is None
@pytest.mark.asyncio
async def test_http_backend_adapter_create_call_record_posts_expected_payload(monkeypatch):
captured = {}
class _FakeResponse:
def __init__(self, status=200, payload=None):
self.status = status
self._payload = payload if payload is not None else {}
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def json(self):
return self._payload
def raise_for_status(self):
if self.status >= 400:
raise RuntimeError("http_error")
class _FakeClientSession:
def __init__(self, timeout=None):
captured["timeout"] = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
def post(self, url, json=None):
captured["url"] = url
captured["json"] = json
return _FakeResponse(status=200, payload={"id": "call_123"})
monkeypatch.setattr("app.backend_adapters.aiohttp.ClientSession", _FakeClientSession)
adapter = build_backend_adapter(
backend_url="http://localhost:8100",
backend_mode="auto",
history_enabled=True,
timeout_sec=7,
)
assert isinstance(adapter, HttpBackendAdapter)
call_id = await adapter.create_call_record(
user_id=99,
assistant_id="assistant_9",
source="debug",
)
assert call_id == "call_123"
assert captured["url"] == "http://localhost:8100/api/history"
assert captured["json"] == {
"user_id": 99,
"assistant_id": "assistant_9",
"source": "debug",
"status": "connected",
}
assert isinstance(captured["timeout"], aiohttp.ClientTimeout)
assert captured["timeout"].total == 7
@pytest.mark.asyncio
async def test_backend_mode_disabled_forces_null_even_with_url():
adapter = build_backend_adapter(
backend_url="http://localhost:8100",
backend_mode="disabled",
history_enabled=True,
timeout_sec=7,
)
assert isinstance(adapter, NullBackendAdapter)
@pytest.mark.asyncio
async def test_history_disabled_wraps_backend_adapter():
adapter = build_backend_adapter(
backend_url="http://localhost:8100",
backend_mode="auto",
history_enabled=False,
timeout_sec=7,
)
assert isinstance(adapter, HistoryDisabledBackendAdapter)
assert await adapter.create_call_record(user_id=1, assistant_id="a1", source="debug") is None
assert await adapter.add_transcript(
call_id="c1",
turn_index=0,
speaker="human",
content="hi",
start_ms=0,
end_ms=10,
duration_ms=10,
) is False

View File

@@ -0,0 +1,147 @@
import asyncio
import time
import pytest
from core.history_bridge import SessionHistoryBridge
class _FakeHistoryWriter:
def __init__(self, *, add_delay_s: float = 0.0, add_result: bool = True):
self.add_delay_s = add_delay_s
self.add_result = add_result
self.created_call_ids = []
self.transcripts = []
self.finalize_calls = 0
self.finalize_statuses = []
self.finalize_at = None
self.last_transcript_at = None
async def create_call_record(self, *, user_id: int, assistant_id: str | None, source: str = "debug"):
_ = (user_id, assistant_id, source)
call_id = "call_test_1"
self.created_call_ids.append(call_id)
return call_id
async def add_transcript(
self,
*,
call_id: str,
turn_index: int,
speaker: str,
content: str,
start_ms: int,
end_ms: int,
confidence: float | None = None,
duration_ms: int | None = None,
) -> bool:
_ = confidence
if self.add_delay_s > 0:
await asyncio.sleep(self.add_delay_s)
self.transcripts.append(
{
"call_id": call_id,
"turn_index": turn_index,
"speaker": speaker,
"content": content,
"start_ms": start_ms,
"end_ms": end_ms,
"duration_ms": duration_ms,
}
)
self.last_transcript_at = time.monotonic()
return self.add_result
async def finalize_call_record(self, *, call_id: str, status: str, duration_seconds: int) -> bool:
_ = (call_id, duration_seconds)
self.finalize_calls += 1
self.finalize_statuses.append(status)
self.finalize_at = time.monotonic()
return True
@pytest.mark.asyncio
async def test_slow_backend_does_not_block_enqueue():
writer = _FakeHistoryWriter(add_delay_s=0.15, add_result=True)
bridge = SessionHistoryBridge(
history_writer=writer,
enabled=True,
queue_max_size=32,
retry_max_attempts=0,
retry_backoff_sec=0.01,
finalize_drain_timeout_sec=1.0,
)
try:
call_id = await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
assert call_id == "call_test_1"
t0 = time.perf_counter()
queued = bridge.enqueue_turn(role="user", text="hello world")
elapsed_s = time.perf_counter() - t0
assert queued is True
assert elapsed_s < 0.02
await bridge.finalize(status="connected")
assert len(writer.transcripts) == 1
assert writer.finalize_calls == 1
finally:
await bridge.shutdown()
@pytest.mark.asyncio
async def test_failing_backend_retries_but_enqueue_remains_non_blocking():
writer = _FakeHistoryWriter(add_delay_s=0.01, add_result=False)
bridge = SessionHistoryBridge(
history_writer=writer,
enabled=True,
queue_max_size=32,
retry_max_attempts=2,
retry_backoff_sec=0.01,
finalize_drain_timeout_sec=0.5,
)
try:
await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
t0 = time.perf_counter()
assert bridge.enqueue_turn(role="assistant", text="retry me")
elapsed_s = time.perf_counter() - t0
assert elapsed_s < 0.02
await bridge.finalize(status="connected")
# Initial try + 2 retries
assert len(writer.transcripts) == 3
assert writer.finalize_calls == 1
finally:
await bridge.shutdown()
@pytest.mark.asyncio
async def test_finalize_is_idempotent_and_waits_for_queue_drain():
writer = _FakeHistoryWriter(add_delay_s=0.05, add_result=True)
bridge = SessionHistoryBridge(
history_writer=writer,
enabled=True,
queue_max_size=32,
retry_max_attempts=0,
retry_backoff_sec=0.01,
finalize_drain_timeout_sec=1.0,
)
try:
await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
assert bridge.enqueue_turn(role="user", text="first")
ok_1 = await bridge.finalize(status="connected")
ok_2 = await bridge.finalize(status="connected")
assert ok_1 is True
assert ok_2 is True
assert len(writer.transcripts) == 1
assert writer.finalize_calls == 1
assert writer.last_transcript_at is not None
assert writer.finalize_at is not None
assert writer.finalize_at >= writer.last_transcript_at
finally:
await bridge.shutdown()
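# Behaviors pinned above: finalize is idempotent and flushes pending turns
# before the final record, so finalize_at >= last_transcript_at. A sketch of
# one way to satisfy both (names and state layout are hypothetical):
async def _finalize_sketch(state: dict, writer, queue: asyncio.Queue, *, status: str, drain_timeout_s: float) -> bool:
    if state.get("finalized"):
        return True  # repeat calls are no-ops, matching ok_2 above
    state["finalized"] = True
    try:
        await asyncio.wait_for(queue.join(), timeout=drain_timeout_s)  # drain first
    except asyncio.TimeoutError:
        pass  # best-effort: finalize even if the drain timed out
    return await writer.finalize_call_record(
        call_id=state["call_id"], status=status, duration_seconds=0
    )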

View File

@@ -92,16 +92,44 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl
return pipeline, events
def test_pipeline_uses_default_tools_from_settings(monkeypatch):
monkeypatch.setattr(
"core.duplex_pipeline.settings.tools",
[
"current_time",
{
"name": "weather",
"description": "Get weather by city",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"executor": "server",
},
],
)
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
cfg = pipeline.resolved_runtime_config()
assert cfg["tools"]["allowlist"] == ["current_time", "weather"]
schemas = pipeline._resolved_tool_schemas()
names = [s.get("function", {}).get("name") for s in schemas if isinstance(s, dict)]
assert "current_time" in names
assert "weather" in names
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
msg = parse_client_message(
{
"type": "tool_call.results",
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
"results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
}
)
assert isinstance(msg, ToolCallResultsMessage)
assert msg.results[0]["tool_call_id"] == "call_1"
assert msg.results[0].tool_call_id == "call_1"
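# The assertion moved from dict indexing to attribute access because results
# are now parsed into typed models. A pydantic sketch of the item shape
# (hypothetical; mirrors only the fields exercised above):
from pydantic import BaseModel

class _ToolCallResultItemSketch(BaseModel):
    tool_call_id: str
    name: str | None = None
    status: dict | None = None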
@pytest.mark.asyncio

View File

@@ -6,6 +6,7 @@ const getApiBaseUrl = (): string => {
const configured = import.meta.env.VITE_API_BASE_URL || DEFAULT_API_BASE_URL;
return trimTrailingSlash(configured);
};
export { getApiBaseUrl };
type RequestOptions = {
method?: 'GET' | 'POST' | 'PUT' | 'DELETE';

View File

@@ -1,7 +1,11 @@
import { ASRModel, Assistant, CallLog, InteractionDetail, KnowledgeBase, KnowledgeDocument, LLMModel, Tool, Voice, Workflow, WorkflowEdge, WorkflowNode } from '../types';
import { apiRequest } from './apiClient';
import { apiRequest, getApiBaseUrl } from './apiClient';
type AnyRecord = Record<string, any>;
const DEFAULT_LIST_LIMIT = 1000;
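// Append a limit query param, using '&' when the path already carries a query string.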
const withLimit = (path: string, limit: number = DEFAULT_LIST_LIMIT): string =>
`${path}${path.includes('?') ? '&' : '?'}limit=${limit}`;
const readField = <T>(obj: AnyRecord, keys: string[], fallback: T): T => {
for (const key of keys) {
@@ -185,7 +189,8 @@ const mapKnowledgeBase = (raw: AnyRecord): KnowledgeBase => ({
const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): CallLog => {
const assistantId = readField(raw, ['assistant_id', 'assistantId'], '');
const startTime = normalizeDateLabel(readField(raw, ['started_at', 'startTime'], ''));
const type = readField(raw, ['type'], 'text');
const rawType = String(readField(raw, ['type'], 'text')).toLowerCase();
const type: CallLog['type'] = rawType === 'audio' || rawType === 'video' ? rawType : 'text';
return {
id: String(readField(raw, ['id'], '')),
source: readField(raw, ['source'], 'debug') as 'debug' | 'external',
@@ -193,7 +198,7 @@ const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): Ca
startTime,
duration: formatDuration(readField(raw, ['duration_seconds', 'durationSeconds'], 0)),
agentName: assistantNameMap.get(String(assistantId)) || String(assistantId || 'Unknown Assistant'),
type: type === 'audio' || type === 'video' ? type : 'text',
type,
details: [],
};
};
@@ -209,7 +214,7 @@ const toHistoryDetails = (raw: AnyRecord): InteractionDetail[] => {
};
export const fetchAssistants = async (): Promise<Assistant[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapAssistant(item));
};
@@ -276,6 +281,8 @@ export const deleteAssistant = async (id: string): Promise<void> => {
export interface AssistantRuntimeConfigResponse {
assistantId: string;
configVersionId?: string;
assistant?: Record<string, any>;
sessionStartMetadata: Record<string, any>;
sources?: {
llmModelId?: string;
@@ -286,11 +293,11 @@ export interface AssistantRuntimeConfigResponse {
}
export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise<AssistantRuntimeConfigResponse> => {
return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/runtime-config`);
return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/config`);
};
export const fetchVoices = async (): Promise<Voice[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/voices'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapVoice(item));
};
@@ -351,7 +358,7 @@ export const previewVoice = async (id: string, text: string, speed?: number, api
};
export const fetchASRModels = async (): Promise<ASRModel[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/asr');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/asr'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapASRModel(item));
};
@@ -418,8 +425,7 @@ export const previewASRModel = async (
formData.append('api_key', options.apiKey);
}
const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
const url = `${base}/asr/${id}/preview`;
const url = `${getApiBaseUrl()}/asr/${id}/preview`;
const response = await fetch(url, {
method: 'POST',
body: formData,
@@ -441,7 +447,7 @@ export const previewASRModel = async (
};
export const fetchLLMModels = async (): Promise<LLMModel[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/llm');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/llm'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapLLMModel(item));
};
@@ -501,7 +507,7 @@ export const previewLLMModel = async (
};
export const fetchTools = async (): Promise<Tool[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/tools/resources');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/tools/resources'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapTool(item));
};
@@ -544,18 +550,14 @@ export const deleteTool = async (id: string): Promise<void> => {
};
export const fetchWorkflows = async (): Promise<Workflow[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/workflows');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/workflows'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapWorkflow(item));
};
export const fetchWorkflowById = async (id: string): Promise<Workflow> => {
const list = await fetchWorkflows();
const workflow = list.find((item) => item.id === id);
if (!workflow) {
throw new Error('Workflow not found');
}
return workflow;
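// Fetch the single record from the dedicated /workflows/{id} endpoint instead of filtering the full list client-side.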
const response = await apiRequest<AnyRecord>(`/workflows/${id}`);
return mapWorkflow(response);
};
export const createWorkflow = async (data: Partial<Workflow>): Promise<Workflow> => {
@@ -589,7 +591,7 @@ export const deleteWorkflow = async (id: string): Promise<void> => {
};
export const fetchKnowledgeBases = async (): Promise<KnowledgeBase[]> => {
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/knowledge/bases');
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/knowledge/bases'));
const list = Array.isArray(response) ? response : (response.list || []);
return list.map((item) => mapKnowledgeBase(item));
};
@@ -641,8 +643,7 @@ export const uploadKnowledgeDocument = async (kbId: string, file: File): Promise
formData.append('size', `${file.size} bytes`);
formData.append('file_type', file.type || 'application/octet-stream');
const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
const url = `${base}/knowledge/bases/${kbId}/documents`;
const url = `${getApiBaseUrl()}/knowledge/bases/${kbId}/documents`;
const response = await fetch(url, {
method: 'POST',
body: formData,
@@ -716,8 +717,8 @@ export const searchKnowledgeBase = async (
export const fetchHistory = async (): Promise<CallLog[]> => {
const [historyResp, assistantsResp] = await Promise.all([
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/history'),
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants'),
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/history')),
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants')),
]);
const assistantList = Array.isArray(assistantsResp) ? assistantsResp : (assistantsResp.list || []);

View File

@@ -11,7 +11,8 @@
],
"skipLibCheck": true,
"types": [
"node"
"node",
"vite/client"
],
"moduleResolution": "bundler",
"isolatedModules": true,