Unify db api
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
# OS artifacts
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Workspace runtime data
|
||||
data/
|
||||
8
api/.gitignore
vendored
8
api/.gitignore
vendored
@@ -36,8 +36,12 @@ env/
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Vector store data
|
||||
data/vector_store/
|
||||
# Runtime data (SQLite, vector store, uploads, generated artifacts)
|
||||
data/**
|
||||
!data/
|
||||
!data/.gitkeep
|
||||
!data/vector_store/
|
||||
data/vector_store/**
|
||||
!data/vector_store/.gitkeep
|
||||
|
||||
# IDE
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Assistant, LLMModel, ASRModel, Voice
|
||||
from ..schemas import (
|
||||
AssistantCreate, AssistantUpdate, AssistantOut
|
||||
AssistantCreate, AssistantUpdate, AssistantOut, AssistantEngineConfigResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
||||
@@ -52,8 +52,13 @@ def _normalize_openai_compatible_voice_key(voice_value: str, model: str) -> str:
|
||||
return f"{model_name}:{voice_id}"
|
||||
|
||||
|
||||
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
metadata = {
|
||||
def _config_version_id(assistant: Assistant) -> str:
|
||||
updated = assistant.updated_at or assistant.created_at or datetime.utcnow()
|
||||
return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
|
||||
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
|
||||
metadata: Dict[str, Any] = {
|
||||
"systemPrompt": assistant.prompt or "",
|
||||
"firstTurnMode": assistant.first_turn_mode or "bot_first",
|
||||
"greeting": assistant.opener or "",
|
||||
@@ -64,10 +69,29 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
"minDurationMs": int(assistant.interruption_sensitivity or 500),
|
||||
},
|
||||
"services": {},
|
||||
"tools": assistant.tools or [],
|
||||
"history": {
|
||||
"assistantId": assistant.id,
|
||||
"userId": int(assistant.user_id or 1),
|
||||
"source": "debug",
|
||||
},
|
||||
}
|
||||
warnings = []
|
||||
warnings: List[str] = []
|
||||
|
||||
if assistant.llm_model_id:
|
||||
config_mode = str(assistant.config_mode or "platform").strip().lower()
|
||||
|
||||
if config_mode in {"dify", "fastgpt"}:
|
||||
metadata["services"]["llm"] = {
|
||||
"provider": "openai",
|
||||
"model": "",
|
||||
"apiKey": assistant.api_key,
|
||||
"baseUrl": assistant.api_url,
|
||||
}
|
||||
if not (assistant.api_url or "").strip():
|
||||
warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}")
|
||||
if not (assistant.api_key or "").strip():
|
||||
warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}")
|
||||
elif assistant.llm_model_id:
|
||||
llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
|
||||
if llm:
|
||||
metadata["services"]["llm"] = {
|
||||
@@ -87,6 +111,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
"provider": asr_provider,
|
||||
"model": asr.model_name or asr.name,
|
||||
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
|
||||
"baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
|
||||
}
|
||||
else:
|
||||
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||
@@ -107,6 +132,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
"provider": tts_provider,
|
||||
"model": model,
|
||||
"apiKey": voice.api_key if tts_provider == "openai_compatible" else None,
|
||||
"baseUrl": voice.base_url if tts_provider == "openai_compatible" else None,
|
||||
"voice": runtime_voice,
|
||||
"speed": assistant.speed or voice.speed,
|
||||
}
|
||||
@@ -126,10 +152,21 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
"kbId": assistant.knowledge_base_id,
|
||||
"nResults": 5,
|
||||
}
|
||||
return metadata, warnings
|
||||
|
||||
|
||||
def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[str, Any]:
|
||||
session_metadata, warnings = _resolve_runtime_metadata(db, assistant)
|
||||
config_version_id = _config_version_id(assistant)
|
||||
assistant_cfg = dict(session_metadata)
|
||||
assistant_cfg["assistantId"] = assistant.id
|
||||
assistant_cfg["configVersionId"] = config_version_id
|
||||
|
||||
return {
|
||||
"assistantId": assistant.id,
|
||||
"sessionStartMetadata": metadata,
|
||||
"configVersionId": config_version_id,
|
||||
"assistant": assistant_cfg,
|
||||
"sessionStartMetadata": session_metadata,
|
||||
"sources": {
|
||||
"llmModelId": assistant.llm_model_id,
|
||||
"asrModelId": assistant.asr_model_id,
|
||||
@@ -219,13 +256,22 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
|
||||
return assistant_to_dict(assistant)
|
||||
|
||||
|
||||
@router.get("/{id}/runtime-config")
|
||||
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
|
||||
"""Resolve assistant runtime config for engine WS session.start metadata."""
|
||||
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
|
||||
def get_assistant_config(id: str, db: Session = Depends(get_db)):
|
||||
"""Canonical engine config endpoint consumed by engine backend adapter."""
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return _resolve_runtime_metadata(db, assistant)
|
||||
return _build_engine_assistant_config(db, assistant)
|
||||
|
||||
|
||||
@router.get(
    "/{id}/runtime-config",
    response_model=AssistantEngineConfigResponse,
    deprecated=True,  # legacy alias: flag it as deprecated in the OpenAPI schema
)
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
    """Legacy alias for resolved engine runtime config.

    Looks the assistant up by ``id`` and returns the payload built by
    ``_build_engine_assistant_config``.

    Raises:
        HTTPException: 404 when no assistant with this ``id`` exists.
    """
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return _build_engine_assistant_config(db, assistant)
|
||||
|
||||
|
||||
@router.post("", response_model=AssistantOut)
|
||||
|
||||
@@ -333,6 +333,35 @@ class AssistantOut(AssistantBase):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AssistantRuntimeMetadata(BaseModel):
    """Canonical runtime metadata payload consumed by engine session.start."""

    # Keys that are not declared below are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")

    # Prompt text and opening-turn behaviour.
    systemPrompt: str = ""
    firstTurnMode: str = "bot_first"
    greeting: str = ""
    generatedOpenerEnabled: bool = False
    # Free-form option groups; their inner keys are not validated here.
    output: Dict[str, Any] = Field(default_factory=dict)
    bargeIn: Dict[str, Any] = Field(default_factory=dict)
    # Per-service settings keyed by service name (the resolver fills e.g. "llm").
    services: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    tools: List[Any] = Field(default_factory=list)
    knowledgeBaseId: Optional[str] = None
    knowledge: Dict[str, Any] = Field(default_factory=dict)
    history: Dict[str, Any] = Field(default_factory=dict)
    # Identity / version fields; optional so the same model fits payloads that omit them.
    assistantId: Optional[str] = None
    configVersionId: Optional[str] = None
|
||||
|
||||
|
||||
class AssistantEngineConfigResponse(BaseModel):
    # Response envelope used as ``response_model`` for the assistant
    # ``/config`` and ``/runtime-config`` routes.
    assistantId: str
    configVersionId: Optional[str] = None
    # Both members share the runtime-metadata schema.
    assistant: AssistantRuntimeMetadata
    sessionStartMetadata: AssistantRuntimeMetadata
    # Source ids the config was resolved from (values may be null).
    sources: Dict[str, Optional[str]] = Field(default_factory=dict)
    # Human-readable notes collected while resolving the config.
    warnings: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssistantStats(BaseModel):
|
||||
assistant_id: str
|
||||
total_calls: int = 0
|
||||
|
||||
@@ -183,12 +183,16 @@ class TestAssistantAPI:
|
||||
|
||||
def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
|
||||
"""Test resolved runtime config endpoint for WS session.start metadata."""
|
||||
sample_asr_model_data["vendor"] = "OpenAI Compatible"
|
||||
llm_resp = client.post("/api/llm", json=sample_llm_model_data)
|
||||
assert llm_resp.status_code == 200
|
||||
|
||||
asr_resp = client.post("/api/asr", json=sample_asr_model_data)
|
||||
assert asr_resp.status_code == 200
|
||||
|
||||
sample_voice_data["vendor"] = "OpenAI Compatible"
|
||||
sample_voice_data["base_url"] = "https://tts.example.com/v1/audio/speech"
|
||||
sample_voice_data["api_key"] = "test-voice-key"
|
||||
voice_resp = client.post("/api/voices", json=sample_voice_data)
|
||||
assert voice_resp.status_code == 200
|
||||
voice_id = voice_resp.json()["id"]
|
||||
@@ -215,7 +219,26 @@ class TestAssistantAPI:
|
||||
assert metadata["greeting"] == "runtime opener"
|
||||
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
|
||||
assert metadata["services"]["tts"]["voice"] == sample_voice_data["voice_key"]
|
||||
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
|
||||
|
||||
def test_get_engine_config_endpoint(self, client, sample_assistant_data):
|
||||
"""Test canonical assistant config endpoint consumed by engine backend adapter."""
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
config_resp = client.get(f"/api/assistants/{assistant_id}/config")
|
||||
assert config_resp.status_code == 200
|
||||
payload = config_resp.json()
|
||||
|
||||
assert payload["assistantId"] == assistant_id
|
||||
assert payload["assistant"]["assistantId"] == assistant_id
|
||||
assert payload["assistant"]["configVersionId"].startswith(f"asst_{assistant_id}_")
|
||||
assert payload["assistant"]["systemPrompt"] == sample_assistant_data["prompt"]
|
||||
assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"]
|
||||
assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id
|
||||
|
||||
def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data):
|
||||
sample_assistant_data["voiceOutputEnabled"] = False
|
||||
|
||||
@@ -11,9 +11,16 @@ PORT=8000
|
||||
# EXTERNAL_IP=1.2.3.4
|
||||
|
||||
# Backend bridge (optional)
|
||||
# BACKEND_MODE=auto|http|disabled
|
||||
BACKEND_MODE=auto
|
||||
BACKEND_URL=http://127.0.0.1:8100
|
||||
BACKEND_TIMEOUT_SEC=10
|
||||
HISTORY_ENABLED=true
|
||||
HISTORY_DEFAULT_USER_ID=1
|
||||
HISTORY_QUEUE_MAX_SIZE=256
|
||||
HISTORY_RETRY_MAX_ATTEMPTS=2
|
||||
HISTORY_RETRY_BACKOFF_SEC=0.2
|
||||
HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC=1.5
|
||||
|
||||
# Audio
|
||||
SAMPLE_RATE=16000
|
||||
@@ -23,57 +30,21 @@ CHUNK_SIZE_MS=20
|
||||
DEFAULT_CODEC=pcm
|
||||
MAX_AUDIO_BUFFER_SECONDS=30
|
||||
|
||||
# VAD / EOU
|
||||
VAD_TYPE=silero
|
||||
VAD_MODEL_PATH=data/vad/silero_vad.onnx
|
||||
# Higher = stricter speech detection (fewer false positives, more misses).
|
||||
VAD_THRESHOLD=0.5
|
||||
# Require this much continuous speech before utterance can be valid.
|
||||
VAD_MIN_SPEECH_DURATION_MS=100
|
||||
# Silence duration required to finalize one user turn.
|
||||
VAD_EOU_THRESHOLD_MS=800
|
||||
# Agent profile selection (optional fallback when CLI args are not used)
|
||||
# Prefer CLI:
|
||||
# python -m app.main --agent-config config/agents/default.yaml
|
||||
# python -m app.main --agent-profile default
|
||||
# AGENT_CONFIG_PATH=config/agents/default.yaml
|
||||
# AGENT_PROFILE=default
|
||||
AGENT_CONFIG_DIR=config/agents
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
# Optional for OpenAI-compatible providers.
|
||||
# OPENAI_API_URL=https://api.openai.com/v1
|
||||
LLM_MODEL=gpt-4o-mini
|
||||
LLM_TEMPERATURE=0.7
|
||||
|
||||
# TTS
|
||||
# edge: no API key needed
|
||||
# openai_compatible: compatible with SiliconFlow-style endpoints
|
||||
TTS_PROVIDER=openai_compatible
|
||||
TTS_VOICE=anna
|
||||
TTS_SPEED=1.0
|
||||
|
||||
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
|
||||
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
|
||||
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
|
||||
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
|
||||
|
||||
# ASR
|
||||
ASR_PROVIDER=openai_compatible
|
||||
# Interim cadence and minimum audio before interim decode.
|
||||
ASR_INTERIM_INTERVAL_MS=500
|
||||
ASR_MIN_AUDIO_MS=300
|
||||
# ASR start gate: ignore micro-noise, then commit to one turn once started.
|
||||
ASR_START_MIN_SPEECH_MS=160
|
||||
# Pre-roll protects beginning phonemes.
|
||||
ASR_PRE_SPEECH_MS=240
|
||||
# Tail silence protects ending phonemes.
|
||||
ASR_FINAL_TAIL_MS=120
|
||||
|
||||
# Duplex behavior
|
||||
DUPLEX_ENABLED=true
|
||||
# DUPLEX_GREETING=Hello! How can I help you today?
|
||||
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||
|
||||
# Barge-in (user interrupting assistant)
|
||||
# Min user speech duration needed to interrupt assistant audio.
|
||||
BARGE_IN_MIN_DURATION_MS=200
|
||||
# Allowed silence during potential barge-in (ms) before reset.
|
||||
BARGE_IN_SILENCE_TOLERANCE_MS=60
|
||||
# Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY}
|
||||
# LLM_API_KEY=your_llm_api_key_here
|
||||
# LLM_API_URL=https://api.openai.com/v1
|
||||
# TTS_API_KEY=your_tts_api_key_here
|
||||
# TTS_API_URL=https://api.example.com/v1/audio/speech
|
||||
# ASR_API_KEY=your_asr_api_key_here
|
||||
# ASR_API_URL=https://api.example.com/v1/audio/transcriptions
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
2
engine/.gitignore
vendored
2
engine/.gitignore
vendored
@@ -146,3 +146,5 @@ cython_debug/
|
||||
recordings/
|
||||
logs/
|
||||
running/
|
||||
|
||||
config/agents/default.yaml
|
||||
|
||||
@@ -14,6 +14,57 @@ It is currently in an early, experimental stage.
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
使用 agent profile(推荐)
|
||||
|
||||
```
|
||||
python -m app.main --agent-profile default
|
||||
```
|
||||
|
||||
使用指定 YAML
|
||||
|
||||
```
|
||||
python -m app.main --agent-config config/agents/default.yaml
|
||||
```
|
||||
|
||||
Agent 配置路径优先级
|
||||
1. `--agent-config`
|
||||
2. `--agent-profile`(映射到 `config/agents/<profile>.yaml`)
|
||||
3. `AGENT_CONFIG_PATH`
|
||||
4. `AGENT_PROFILE`
|
||||
5. `config/agents/default.yaml`(若存在)
|
||||
|
||||
说明
|
||||
- Agent 相关配置是严格模式:YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。
|
||||
- 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。
|
||||
- `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。
|
||||
- 现在支持在 Agent YAML 中配置 `agent.tools`(列表),用于声明运行时可调用工具。
|
||||
- 工具配置示例见 `config/agents/tools.yaml`。
|
||||
|
||||
## Backend Integration
|
||||
|
||||
Engine runtime now supports adapter-based backend integration:
|
||||
|
||||
- `BACKEND_MODE=auto|http|disabled`
|
||||
- `BACKEND_URL` + `BACKEND_TIMEOUT_SEC`
|
||||
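- `HISTORY_ENABLED=true|false` (when `false`, history writes are skipped; assistant config, knowledge search and tool resource reads still go to the backend)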
|
||||
|
||||
Behavior:
|
||||
|
||||
- `auto`: use HTTP backend only when `BACKEND_URL` is set, otherwise engine-only mode.
|
||||
- `http`: force HTTP backend; if `BACKEND_URL` is missing, a warning is logged and the engine falls back to engine-only mode.
|
||||
- `disabled`: force engine-only mode (no backend calls).
|
||||
|
||||
History write path is now asynchronous and buffered per session:
|
||||
|
||||
- `HISTORY_QUEUE_MAX_SIZE`
|
||||
- `HISTORY_RETRY_MAX_ATTEMPTS`
|
||||
- `HISTORY_RETRY_BACKOFF_SEC`
|
||||
- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC`
|
||||
|
||||
This keeps turn processing responsive even when backend history APIs are slow/failing.
|
||||
|
||||
Detailed notes: `docs/backend_integration.md`.
|
||||
|
||||
测试
|
||||
|
||||
```
|
||||
@@ -28,4 +79,4 @@ python mic_client.py
|
||||
|
||||
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
|
||||
|
||||
See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`.
|
||||
See `docs/ws_v1_schema.md`.
|
||||
|
||||
357
engine/app/backend_adapters.py
Normal file
357
engine/app/backend_adapters.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Backend adapter implementations for engine integration ports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class NullBackendAdapter:
    """No-op adapter for engine-only runtime without backend dependencies.

    Every method ignores its arguments and returns the "nothing available"
    value for its return type: ``None``, ``False`` or an empty list.
    """

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Return ``None``: there is no backend to read a config from."""
        del assistant_id  # intentionally unused
        return None

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Return ``None``: no call record is created."""
        del user_id, assistant_id, source  # intentionally unused
        return None

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Return ``False``: the transcript segment is not stored."""
        del call_id, turn_index, speaker, content  # intentionally unused
        del start_ms, end_ms, confidence, duration_ms
        return False

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Return ``False``: there is no call record to finalize."""
        del call_id, status, duration_seconds  # intentionally unused
        return False

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Return an empty result list: no knowledge base is reachable."""
        del kb_id, query, n_results  # intentionally unused
        return []

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Return ``None``: no tool resource can be fetched."""
        del tool_id  # intentionally unused
        return None
|
||||
|
||||
|
||||
class HistoryDisabledBackendAdapter:
    """Adapter wrapper that disables history writes while keeping reads available.

    Read-style calls (assistant config, knowledge search, tool resources) are
    forwarded to the wrapped adapter. History calls never reach the delegate
    and report "nothing written" (``None`` / ``False``).
    """

    def __init__(self, delegate: HttpBackendAdapter | NullBackendAdapter):
        self._delegate = delegate

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Forward the config lookup to the wrapped adapter."""
        config = await self._delegate.fetch_assistant_config(assistant_id)
        return config

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """History is disabled: create nothing and return ``None``."""
        del user_id, assistant_id, source  # intentionally unused
        return None

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """History is disabled: store nothing and return ``False``."""
        del call_id, turn_index, speaker, content  # intentionally unused
        del start_ms, end_ms, confidence, duration_ms
        return False

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """History is disabled: finalize nothing and return ``False``."""
        del call_id, status, duration_seconds  # intentionally unused
        return False

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Forward the knowledge search to the wrapped adapter."""
        hits = await self._delegate.search_knowledge_context(
            kb_id=kb_id,
            query=query,
            n_results=n_results,
        )
        return hits

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Forward the tool resource lookup to the wrapped adapter."""
        resource = await self._delegate.fetch_tool_resource(tool_id)
        return resource
|
||||
|
||||
|
||||
class HttpBackendAdapter:
    """HTTP implementation of backend integration ports.

    Each call opens its own ``aiohttp.ClientSession`` with the configured total
    timeout. Failures are never raised to the caller: they are logged as
    warnings and reported through the return value (``None``, ``False`` or an
    empty list).
    """

    def __init__(self, backend_url: str, timeout_sec: int = 10):
        """Store the normalized base URL and the request timeout.

        Args:
            backend_url: Backend base URL; surrounding whitespace and trailing
                slashes are removed.
            timeout_sec: Total timeout in seconds applied to every request.

        Raises:
            ValueError: If ``backend_url`` is empty after normalization.
        """
        base_url = str(backend_url or "").strip().rstrip("/")
        if not base_url:
            raise ValueError("backend_url is required for HttpBackendAdapter")
        self._base_url = base_url
        self._timeout_sec = timeout_sec

    def _timeout(self) -> aiohttp.ClientTimeout:
        """Build the per-request timeout object."""
        return aiohttp.ClientTimeout(total=self._timeout_sec)

    @staticmethod
    def _path_segment(value: Any) -> str:
        """Percent-encode ``value`` so it is safe as a single URL path segment.

        Ids are caller-supplied; without encoding, characters such as ``/``,
        ``?`` or ``#`` would change which backend resource is addressed.
        """
        from urllib.parse import quote  # local import keeps file-level imports untouched

        return quote(str(value), safe="")

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Fetch assistant config payload from backend API.

        Calls ``GET /api/assistants/{assistant_id}/config``. The backend's
        response model for that route exposes these keys:
            {
                "assistantId": str,
                "configVersionId": str | null,
                "assistant": {...},
                "sessionStartMetadata": {...},
                "sources": {...},
                "warnings": [...]
            }

        Returns:
            The decoded JSON object, or ``None`` on 404, on a non-dict payload
            or on any request failure.
        """
        url = f"{self._base_url}/api/assistants/{self._path_segment(assistant_id)}/config"

        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.get(url) as resp:
                    if resp.status == 404:
                        logger.warning(f"Assistant config not found: {assistant_id}")
                        return None
                    resp.raise_for_status()
                    payload = await resp.json()
                    if not isinstance(payload, dict):
                        logger.warning("Assistant config payload is not a dict; ignoring")
                        return None
                    return payload
        except Exception as exc:
            logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
            return None

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Create a call record via backend history API and return call_id.

        Posts to ``/api/history`` with status ``"connected"``.

        Returns:
            The ``id`` from the response as a string, or ``None`` when the id
            is missing/empty or the request fails.
        """
        url = f"{self._base_url}/api/history"
        payload: Dict[str, Any] = {
            "user_id": user_id,
            "assistant_id": assistant_id,
            "source": source,
            "status": "connected",
        }

        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.post(url, json=payload) as resp:
                    resp.raise_for_status()
                    data = await resp.json()
                    call_id = str((data or {}).get("id") or "")
                    return call_id or None
        except Exception as exc:
            logger.warning(f"Failed to create history call record: {exc}")
            return None

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Append a transcript segment to backend history.

        Posts to ``/api/history/{call_id}/transcripts``.

        Returns:
            ``True`` when the backend accepted the segment; ``False`` when
            ``call_id`` is empty or the request fails.
        """
        if not call_id:
            return False

        url = f"{self._base_url}/api/history/{self._path_segment(call_id)}/transcripts"
        payload: Dict[str, Any] = {
            "turn_index": turn_index,
            "speaker": speaker,
            "content": content,
            "confidence": confidence,
            "start_ms": start_ms,
            "end_ms": end_ms,
            "duration_ms": duration_ms,
        }

        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.post(url, json=payload) as resp:
                    resp.raise_for_status()
                    return True
        except Exception as exc:
            logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
            return False

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Finalize a call record with status and duration.

        Sends ``PUT /api/history/{call_id}``.

        Returns:
            ``True`` on success; ``False`` when ``call_id`` is empty or the
            request fails.
        """
        if not call_id:
            return False

        url = f"{self._base_url}/api/history/{self._path_segment(call_id)}"
        payload: Dict[str, Any] = {
            "status": status,
            "duration_seconds": duration_seconds,
        }

        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.put(url, json=payload) as resp:
                    resp.raise_for_status()
                    return True
        except Exception as exc:
            logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
            return False

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search backend knowledge base and return retrieval results.

        Posts to ``/api/knowledge/search``. ``n_results`` is coerced to an
        integer of at least 1 and falls back to 5 when it cannot be converted.

        Returns:
            The dict entries of the response's ``results`` list; an empty list
            when ``kb_id``/``query`` is blank, on 404, on an unexpected payload
            shape or on any request failure.
        """
        if not kb_id or not query.strip():
            return []
        try:
            safe_n_results = max(1, int(n_results))
        except (TypeError, ValueError):
            safe_n_results = 5

        url = f"{self._base_url}/api/knowledge/search"
        payload: Dict[str, Any] = {
            "kb_id": kb_id,
            "query": query,
            "nResults": safe_n_results,
        }

        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.post(url, json=payload) as resp:
                    if resp.status == 404:
                        logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                        return []
                    resp.raise_for_status()
                    data = await resp.json()
                    if not isinstance(data, dict):
                        return []
                    results = data.get("results", [])
                    if not isinstance(results, list):
                        return []
                    # Drop malformed entries instead of failing the whole search.
                    return [r for r in results if isinstance(r, dict)]
        except Exception as exc:
            logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
            return []

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Fetch tool resource configuration from backend API.

        Calls ``GET /api/tools/resources/{tool_id}``.

        Returns:
            The decoded JSON object, or ``None`` when ``tool_id`` is empty, on
            404, on a non-dict payload or on any request failure.
        """
        if not tool_id:
            return None

        url = f"{self._base_url}/api/tools/resources/{self._path_segment(tool_id)}"
        try:
            async with aiohttp.ClientSession(timeout=self._timeout()) as session:
                async with session.get(url) as resp:
                    if resp.status == 404:
                        return None
                    resp.raise_for_status()
                    data = await resp.json()
                    return data if isinstance(data, dict) else None
        except Exception as exc:
            logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
            return None
|
||||
|
||||
|
||||
def build_backend_adapter(
    *,
    backend_url: Optional[str],
    backend_mode: str = "auto",
    history_enabled: bool = True,
    timeout_sec: int = 10,
) -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter:
    """Create backend adapter implementation based on runtime settings.

    Args:
        backend_url: Backend base URL; ``None`` or blank means "not configured".
        backend_mode: Case-insensitive mode. ``disabled`` (aliases: ``off``,
            ``none``, ``null``, ``engine_only``, ``engine-only``) always yields
            the no-op adapter. ``http`` and ``auto`` yield the HTTP adapter
            when a URL is configured and the no-op adapter otherwise. Any other
            value is handled like ``auto`` after logging a warning.
        history_enabled: When false, the chosen adapter is wrapped in
            ``HistoryDisabledBackendAdapter``.
        timeout_sec: Request timeout handed to the HTTP adapter.

    Returns:
        The adapter selected by the rules above.
    """
    mode = str(backend_mode or "auto").strip().lower()
    has_url = bool(str(backend_url or "").strip())

    disabled_modes = {"disabled", "off", "none", "null", "engine_only", "engine-only"}
    if mode not in disabled_modes and mode not in {"auto", "http"}:
        # A typo such as "htpp" used to be accepted silently; keep the auto
        # behaviour but make the misconfiguration visible.
        logger.warning(f"Unknown BACKEND_MODE={backend_mode!r}; treating it as auto")

    base_adapter: HttpBackendAdapter | NullBackendAdapter
    if mode in disabled_modes:
        base_adapter = NullBackendAdapter()
    elif has_url:
        base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec)
    else:
        if mode == "http":
            logger.warning("BACKEND_MODE=http but BACKEND_URL is empty; falling back to NullBackendAdapter")
        base_adapter = NullBackendAdapter()

    if not history_enabled:
        return HistoryDisabledBackendAdapter(base_adapter)
    return base_adapter
|
||||
|
||||
|
||||
def build_backend_adapter_from_settings() -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter:
    """Create backend adapter using current app settings.

    Reads ``backend_url``, ``backend_mode``, ``history_enabled`` and
    ``backend_timeout_sec`` from ``settings`` and forwards them to
    ``build_backend_adapter``.
    """
    options = {
        "backend_url": settings.backend_url,
        "backend_mode": settings.backend_mode,
        "history_enabled": settings.history_enabled,
        "timeout_sec": settings.backend_timeout_sec,
    }
    return build_backend_adapter(**options)
|
||||
@@ -1,56 +1,19 @@
|
||||
"""Backend API client for assistant config and history persistence."""
|
||||
"""Compatibility wrappers around backend adapter implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
|
||||
from app.config import settings
|
||||
|
||||
def _adapter():
|
||||
return build_backend_adapter_from_settings()
|
||||
|
||||
|
||||
async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch assistant config payload from backend API.
|
||||
|
||||
Expected response shape:
|
||||
{
|
||||
"assistant": {...},
|
||||
"voice": {...} | null
|
||||
}
|
||||
"""
|
||||
if not settings.backend_url:
|
||||
logger.warning("BACKEND_URL not set; skipping assistant config fetch")
|
||||
return None
|
||||
|
||||
url = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config"
|
||||
timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 404:
|
||||
logger.warning(f"Assistant config not found: {assistant_id}")
|
||||
return None
|
||||
resp.raise_for_status()
|
||||
payload = await resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning("Assistant config payload is not a dict; ignoring")
|
||||
return None
|
||||
return payload
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def _backend_base_url() -> Optional[str]:
|
||||
if not settings.backend_url:
|
||||
return None
|
||||
return settings.backend_url.rstrip("/")
|
||||
|
||||
|
||||
def _timeout() -> aiohttp.ClientTimeout:
|
||||
return aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
|
||||
"""Fetch assistant config payload from backend adapter."""
|
||||
return await _adapter().fetch_assistant_config(assistant_id)
|
||||
|
||||
|
||||
async def create_history_call_record(
|
||||
@@ -60,28 +23,11 @@ async def create_history_call_record(
|
||||
source: str = "debug",
|
||||
) -> Optional[str]:
|
||||
"""Create a call record via backend history API and return call_id."""
|
||||
base_url = _backend_base_url()
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
url = f"{base_url}/api/history"
|
||||
payload: Dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"assistant_id": assistant_id,
|
||||
"source": source,
|
||||
"status": "connected",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
call_id = str((data or {}).get("id") or "")
|
||||
return call_id or None
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to create history call record: {exc}")
|
||||
return None
|
||||
return await _adapter().create_call_record(
|
||||
user_id=user_id,
|
||||
assistant_id=assistant_id,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
async def add_history_transcript(
|
||||
@@ -96,29 +42,16 @@ async def add_history_transcript(
|
||||
duration_ms: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Append a transcript segment to backend history."""
|
||||
base_url = _backend_base_url()
|
||||
if not base_url or not call_id:
|
||||
return False
|
||||
|
||||
url = f"{base_url}/api/history/{call_id}/transcripts"
|
||||
payload: Dict[str, Any] = {
|
||||
"turn_index": turn_index,
|
||||
"speaker": speaker,
|
||||
"content": content,
|
||||
"confidence": confidence,
|
||||
"start_ms": start_ms,
|
||||
"end_ms": end_ms,
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
|
||||
return False
|
||||
return await _adapter().add_transcript(
|
||||
call_id=call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=content,
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
confidence=confidence,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
|
||||
async def finalize_history_call_record(
|
||||
@@ -128,24 +61,11 @@ async def finalize_history_call_record(
|
||||
duration_seconds: int,
|
||||
) -> bool:
|
||||
"""Finalize a call record with status and duration."""
|
||||
base_url = _backend_base_url()
|
||||
if not base_url or not call_id:
|
||||
return False
|
||||
|
||||
url = f"{base_url}/api/history/{call_id}"
|
||||
payload: Dict[str, Any] = {
|
||||
"status": status,
|
||||
"duration_seconds": duration_seconds,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||
async with session.put(url, json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
|
||||
return False
|
||||
return await _adapter().finalize_call_record(
|
||||
call_id=call_id,
|
||||
status=status,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
|
||||
|
||||
async def search_knowledge_context(
|
||||
@@ -155,57 +75,13 @@ async def search_knowledge_context(
|
||||
n_results: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search backend knowledge base and return retrieval results."""
|
||||
base_url = _backend_base_url()
|
||||
if not base_url:
|
||||
return []
|
||||
if not kb_id or not query.strip():
|
||||
return []
|
||||
try:
|
||||
safe_n_results = max(1, int(n_results))
|
||||
except (TypeError, ValueError):
|
||||
safe_n_results = 5
|
||||
|
||||
url = f"{base_url}/api/knowledge/search"
|
||||
payload: Dict[str, Any] = {
|
||||
"kb_id": kb_id,
|
||||
"query": query,
|
||||
"nResults": safe_n_results,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status == 404:
|
||||
logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
|
||||
return []
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
if not isinstance(data, dict):
|
||||
return []
|
||||
results = data.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
return []
|
||||
return [r for r in results if isinstance(r, dict)]
|
||||
except Exception as exc:
|
||||
logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
|
||||
return []
|
||||
return await _adapter().search_knowledge_context(
|
||||
kb_id=kb_id,
|
||||
query=query,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
|
||||
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch tool resource configuration from backend API."""
|
||||
base_url = _backend_base_url()
|
||||
if not base_url or not tool_id:
|
||||
return None
|
||||
|
||||
url = f"{base_url}/api/tools/resources/{tool_id}"
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 404:
|
||||
return None
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
return data if isinstance(data, dict) else None
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
|
||||
return None
|
||||
return await _adapter().fetch_tool_resource(tool_id)
|
||||
|
||||
@@ -1,9 +1,360 @@
|
||||
"""Configuration management using Pydantic settings."""
|
||||
"""Configuration management using Pydantic settings and agent YAML profiles."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
import json
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError: # pragma: no cover - validated when agent YAML is used
|
||||
yaml = None
|
||||
|
||||
|
||||
# Matches ${NAME} and ${NAME:default} placeholders: group 1 is the variable name,
# group 2 is the optional default text after the colon (None when no colon is given).
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
|
||||
# Directory used for agent profiles when the AGENT_CONFIG_DIR environment variable is not set.
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
|
||||
# File looked up inside the agent config directory when neither a path nor a profile is selected.
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
|
||||
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
|
||||
"vad": {
|
||||
"type": "vad_type",
|
||||
"model_path": "vad_model_path",
|
||||
"threshold": "vad_threshold",
|
||||
"min_speech_duration_ms": "vad_min_speech_duration_ms",
|
||||
"eou_threshold_ms": "vad_eou_threshold_ms",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "llm_provider",
|
||||
"model": "llm_model",
|
||||
"temperature": "llm_temperature",
|
||||
"api_key": "llm_api_key",
|
||||
"api_url": "llm_api_url",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "tts_provider",
|
||||
"api_key": "tts_api_key",
|
||||
"api_url": "tts_api_url",
|
||||
"model": "tts_model",
|
||||
"voice": "tts_voice",
|
||||
"speed": "tts_speed",
|
||||
},
|
||||
"asr": {
|
||||
"provider": "asr_provider",
|
||||
"api_key": "asr_api_key",
|
||||
"api_url": "asr_api_url",
|
||||
"model": "asr_model",
|
||||
"interim_interval_ms": "asr_interim_interval_ms",
|
||||
"min_audio_ms": "asr_min_audio_ms",
|
||||
"start_min_speech_ms": "asr_start_min_speech_ms",
|
||||
"pre_speech_ms": "asr_pre_speech_ms",
|
||||
"final_tail_ms": "asr_final_tail_ms",
|
||||
},
|
||||
"duplex": {
|
||||
"enabled": "duplex_enabled",
|
||||
"greeting": "duplex_greeting",
|
||||
"system_prompt": "duplex_system_prompt",
|
||||
},
|
||||
"barge_in": {
|
||||
"min_duration_ms": "barge_in_min_duration_ms",
|
||||
"silence_tolerance_ms": "barge_in_silence_tolerance_ms",
|
||||
},
|
||||
}
|
||||
_AGENT_SETTING_KEYS = {
|
||||
"vad_type",
|
||||
"vad_model_path",
|
||||
"vad_threshold",
|
||||
"vad_min_speech_duration_ms",
|
||||
"vad_eou_threshold_ms",
|
||||
"llm_provider",
|
||||
"llm_api_key",
|
||||
"llm_api_url",
|
||||
"llm_model",
|
||||
"llm_temperature",
|
||||
"tts_provider",
|
||||
"tts_api_key",
|
||||
"tts_api_url",
|
||||
"tts_model",
|
||||
"tts_voice",
|
||||
"tts_speed",
|
||||
"asr_provider",
|
||||
"asr_api_key",
|
||||
"asr_api_url",
|
||||
"asr_model",
|
||||
"asr_interim_interval_ms",
|
||||
"asr_min_audio_ms",
|
||||
"asr_start_min_speech_ms",
|
||||
"asr_pre_speech_ms",
|
||||
"asr_final_tail_ms",
|
||||
"duplex_enabled",
|
||||
"duplex_greeting",
|
||||
"duplex_system_prompt",
|
||||
"barge_in_min_duration_ms",
|
||||
"barge_in_silence_tolerance_ms",
|
||||
"tools",
|
||||
}
|
||||
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
|
||||
"vad_type",
|
||||
"vad_model_path",
|
||||
"vad_threshold",
|
||||
"vad_min_speech_duration_ms",
|
||||
"vad_eou_threshold_ms",
|
||||
"llm_provider",
|
||||
"llm_model",
|
||||
"llm_temperature",
|
||||
"tts_provider",
|
||||
"tts_voice",
|
||||
"tts_speed",
|
||||
"asr_provider",
|
||||
"asr_interim_interval_ms",
|
||||
"asr_min_audio_ms",
|
||||
"asr_start_min_speech_ms",
|
||||
"asr_pre_speech_ms",
|
||||
"asr_final_tail_ms",
|
||||
"duplex_enabled",
|
||||
"duplex_system_prompt",
|
||||
"barge_in_min_duration_ms",
|
||||
"barge_in_silence_tolerance_ms",
|
||||
}
|
||||
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
|
||||
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
|
||||
return str(overrides.get(key) or default).strip().lower()
|
||||
|
||||
|
||||
def _is_blank(value: Any) -> bool:
|
||||
return value is None or (isinstance(value, str) and not value.strip())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AgentConfigSelection:
    """Immutable record of which agent YAML file was picked and by which rule."""

    # Location of the selected YAML file (the annotation also permits None).
    path: Optional[Path]
    # Label describing how the file was selected.
    source: str
|
||||
|
||||
|
||||
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Parse only agent-related CLI flags from argv."""
|
||||
config_path: Optional[str] = None
|
||||
profile: Optional[str] = None
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg.startswith("--agent-config="):
|
||||
config_path = arg.split("=", 1)[1].strip() or None
|
||||
elif arg == "--agent-config" and i + 1 < len(argv):
|
||||
config_path = argv[i + 1].strip() or None
|
||||
i += 1
|
||||
elif arg.startswith("--agent-profile="):
|
||||
profile = arg.split("=", 1)[1].strip() or None
|
||||
elif arg == "--agent-profile" and i + 1 < len(argv):
|
||||
profile = argv[i + 1].strip() or None
|
||||
i += 1
|
||||
i += 1
|
||||
return config_path, profile
|
||||
|
||||
|
||||
def _agent_config_dir() -> Path:
|
||||
base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR))
|
||||
if not base_dir.is_absolute():
|
||||
base_dir = Path.cwd() / base_dir
|
||||
return base_dir.resolve()
|
||||
|
||||
|
||||
def _resolve_agent_selection(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> AgentConfigSelection:
    """Decide which agent YAML file to load and report how it was chosen.

    Lookup order: an explicit path (function argument, then CLI flag, then the
    ``AGENT_CONFIG_PATH`` variable); a profile name (argument, CLI flag, then
    ``AGENT_PROFILE``) mapped to ``<agent config dir>/<profile>.yaml``; finally
    the default file inside the agent config directory, if it exists.

    Args:
        agent_config_path: Explicit YAML path.
        agent_profile: Profile name.
        argv: Arguments to scan for agent flags; ``sys.argv[1:]`` when None.

    Returns:
        The resolved absolute path together with its selection source.

    Raises:
        ValueError: When nothing is selected and no default file exists, or the
            selected path does not exist or is not a regular file.
    """
    cli_args = list(sys.argv[1:] if argv is None else argv)
    cli_path, cli_profile = _parse_cli_agent_args(cli_args)

    # Arguments and CLI flags share the "cli_*" labels; only values that come
    # from the environment are labelled "env_*".
    explicit_path = agent_config_path or cli_path
    explicit_profile = agent_profile or cli_profile
    path_value = explicit_path or os.getenv("AGENT_CONFIG_PATH")
    profile_value = explicit_profile or os.getenv("AGENT_PROFILE")

    if path_value:
        source = "cli_path" if explicit_path else "env_path"
        candidate = Path(path_value)
    elif profile_value:
        source = "cli_profile" if explicit_profile else "env_profile"
        candidate = _agent_config_dir() / f"{profile_value}.yaml"
    else:
        default_file = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
        if not default_file.exists():
            raise ValueError(
                "Agent YAML config is required. Provide --agent-config/--agent-profile "
                "or create config/agents/default.yaml."
            )
        source = "default"
        candidate = default_file

    anchored = candidate if candidate.is_absolute() else Path.cwd() / candidate
    candidate = anchored.resolve()

    if not candidate.exists():
        raise ValueError(f"Agent config file not found ({source}): {candidate}")
    if not candidate.is_file():
        raise ValueError(f"Agent config path is not a file: {candidate}")
    return AgentConfigSelection(path=candidate, source=source)
|
||||
|
||||
|
||||
def _resolve_env_refs(value: Any) -> Any:
    """Expand ``${ENV_VAR}`` / ``${ENV_VAR:default}`` placeholders inside ``value``.

    Dicts (values only) and lists are walked recursively. Strings without
    ``${`` and all non-string scalars are returned unchanged.

    Raises:
        ValueError: When a referenced variable is not set and the placeholder
            carries no default.
    """
    if isinstance(value, str):
        if "${" not in value:
            return value

        def _substitute(ref: re.Match[str]) -> str:
            name, fallback = ref.group(1), ref.group(2)
            from_env = os.getenv(name)
            if from_env is not None:
                return from_env
            if fallback is not None:
                return fallback
            raise ValueError(f"Missing environment variable referenced in agent YAML: {name}")

        return _ENV_REF_PATTERN.sub(_substitute, value)

    if isinstance(value, dict):
        return {k: _resolve_env_refs(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_resolve_env_refs(item) for item in value]
    return value
|
||||
|
||||
|
||||
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the agent YAML mapping into Settings-style field names.

    Known sections are translated through ``_AGENT_SECTION_KEY_MAP``, ``tools``
    is copied as-is once confirmed to be a list, and any other top-level key is
    kept under its own name and then checked against ``_AGENT_SETTING_KEYS``.

    Raises:
        ValueError: For the retired ``siliconflow`` section, a non-list
            ``tools`` value, a known section that is not a mapping, an unknown
            key inside a section, or unknown flat keys.
    """
    flat: Dict[str, Any] = {}

    for section, payload in raw.items():
        if section == "siliconflow":
            raise ValueError(
                "Section 'siliconflow' is no longer supported. "
                "Move provider-specific fields into agent.llm / agent.asr / agent.tts."
            )

        if section == "tools":
            if not isinstance(payload, list):
                raise ValueError("Agent config key 'tools' must be a list")
            flat["tools"] = payload
        elif section not in _AGENT_SECTION_KEY_MAP:
            # Not a nested section: keep it as a flat key, validated below.
            flat[section] = payload
        else:
            if not isinstance(payload, dict):
                raise ValueError(f"Agent config section '{section}' must be a mapping")
            translation = _AGENT_SECTION_KEY_MAP[section]
            for field, field_value in payload.items():
                if field not in translation:
                    raise ValueError(f"Unknown key in '{section}' section: '{field}'")
                flat[translation[field]] = field_value

    unexpected = sorted(set(flat) - _AGENT_SETTING_KEYS)
    if unexpected:
        raise ValueError(f"Unknown agent config keys in YAML: {', '.join(unexpected)}")
    return flat
|
||||
|
||||
|
||||
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
    """List the required agent settings that are absent or blank, sorted by name.

    A setting counts as missing when its key is not present, its value is None
    (for example a YAML key left empty), or its value is a whitespace-only
    string. On top of the base set, provider-specific settings are required:

    * ``llm_api_key`` when the LLM provider is ``openai`` or OpenAI-compatible;
    * ``tts_api_key`` / ``tts_api_url`` / ``tts_model`` when the TTS provider is
      OpenAI-compatible;
    * ``asr_api_key`` / ``asr_api_url`` / ``asr_model`` when the ASR provider is
      OpenAI-compatible.

    Args:
        overrides: Flat agent settings.

    Returns:
        Sorted names of the missing settings; empty when nothing is missing.
    """

    def _absent(key: str) -> bool:
        # .get() yields None for absent keys, which _is_blank also reports.
        return _is_blank(overrides.get(key))

    # Check every base key for None/blank, not only the string-valued ones, so
    # an empty numeric or boolean entry is reported here as well.
    missing = {key for key in _BASE_REQUIRED_AGENT_SETTING_KEYS if _absent(key)}

    conditional: List[str] = []

    llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
    if llm_provider == "openai" or llm_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        conditional.append("llm_api_key")

    tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
    if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        conditional.extend(("tts_api_key", "tts_api_url", "tts_model"))

    asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
    if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        conditional.extend(("asr_api_key", "asr_api_url", "asr_model"))

    missing.update(key for key in conditional if _absent(key))
    return sorted(missing)
|
||||
|
||||
|
||||
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
    """Read the selected agent YAML and return it as flat Settings overrides.

    Steps: parse the file with ``yaml.safe_load``, unwrap an optional top-level
    ``agent`` mapping, expand environment placeholders, flatten the sections,
    and verify that every required setting is present. The returned dict also
    carries ``agent_config_path`` and ``agent_config_source``.

    Raises:
        RuntimeError: When PyYAML could not be imported.
        ValueError: When the document (or its ``agent`` entry) is not a
            mapping, or required settings are missing.
    """
    if yaml is None:
        raise RuntimeError(
            "PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
        )

    with selection.path.open("r", encoding="utf-8") as handle:
        # Falsy parse results (such as an empty file) are treated as an empty mapping.
        document = yaml.safe_load(handle) or {}

    if not isinstance(document, dict):
        raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")

    if "agent" in document:
        document = document["agent"]
        if not isinstance(document, dict):
            raise ValueError("The 'agent' key in YAML must be a mapping")

    overrides = _normalize_agent_overrides(_resolve_env_refs(document))

    missing = _missing_required_keys(overrides)
    if missing:
        raise ValueError(
            f"Missing required agent settings in YAML ({selection.path}): " + ", ".join(missing)
        )

    overrides.update(
        agent_config_path=str(selection.path),
        agent_config_source=selection.source,
    )
    return overrides
|
||||
|
||||
|
||||
def load_settings(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> "Settings":
    """Build a ``Settings`` instance with the selected agent YAML applied on top.

    Args:
        agent_config_path: Explicit path to an agent YAML file.
        agent_profile: Profile name to look up in the agent config directory.
        argv: Arguments scanned for agent flags; forwarded unchanged to the
            selection step.

    Returns:
        ``Settings`` constructed with the flattened agent overrides as keyword
        arguments.

    Raises:
        ValueError: When no agent YAML can be selected or the file is invalid.
        RuntimeError: When PyYAML is not installed.
    """
    chosen = _resolve_agent_selection(
        agent_config_path=agent_config_path,
        agent_profile=agent_profile,
        argv=argv,
    )
    return Settings(**_load_agent_overrides(chosen))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -37,30 +388,35 @@ class Settings(BaseSettings):
|
||||
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
|
||||
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
|
||||
|
||||
# OpenAI / LLM Configuration
|
||||
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
|
||||
openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
|
||||
# LLM Configuration
|
||||
llm_provider: str = Field(
|
||||
default="openai",
|
||||
description="LLM provider (openai, openai_compatible, siliconflow)"
|
||||
)
|
||||
llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
|
||||
llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
|
||||
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
|
||||
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
|
||||
|
||||
# TTS Configuration
|
||||
tts_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
|
||||
description="TTS provider (edge, openai_compatible, siliconflow)"
|
||||
)
|
||||
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
|
||||
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
|
||||
tts_model: Optional[str] = Field(default=None, description="TTS model name")
|
||||
tts_voice: str = Field(default="anna", description="TTS voice name")
|
||||
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
|
||||
|
||||
# SiliconFlow Configuration
|
||||
siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
|
||||
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
|
||||
|
||||
# ASR Configuration
|
||||
asr_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
|
||||
description="ASR provider (openai_compatible, buffered, siliconflow)"
|
||||
)
|
||||
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
|
||||
asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
|
||||
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
|
||||
asr_model: Optional[str] = Field(default=None, description="ASR model name")
|
||||
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
|
||||
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
|
||||
asr_start_min_speech_ms: int = Field(
|
||||
@@ -94,6 +450,10 @@ class Settings(BaseSettings):
|
||||
description="How much silence (ms) is tolerated during potential barge-in before reset"
|
||||
)
|
||||
|
||||
# Optional tool declarations from agent YAML.
|
||||
# Supports OpenAI function schema style entries and/or shorthand string names.
|
||||
tools: List[Any] = Field(default_factory=list, description="Default tool definitions for runtime")
|
||||
|
||||
# Logging
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(default="json", description="Log format (json or text)")
|
||||
@@ -118,9 +478,25 @@ class Settings(BaseSettings):
|
||||
ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
|
||||
|
||||
# Backend bridge configuration (for call/transcript persistence)
|
||||
backend_mode: str = Field(
|
||||
default="auto",
|
||||
description="Backend integration mode: auto | http | disabled"
|
||||
)
|
||||
backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)")
|
||||
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
|
||||
history_enabled: bool = Field(default=True, description="Enable history write bridge")
|
||||
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
|
||||
history_queue_max_size: int = Field(default=256, description="Max buffered transcript writes per session")
|
||||
history_retry_max_attempts: int = Field(default=2, description="Retry attempts for each transcript write")
|
||||
history_retry_backoff_sec: float = Field(default=0.2, description="Base retry backoff for transcript writes")
|
||||
history_finalize_drain_timeout_sec: float = Field(
|
||||
default=1.5,
|
||||
description="Max wait before finalizing history when queue is still draining"
|
||||
)
|
||||
|
||||
# Agent YAML metadata
|
||||
agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
|
||||
agent_config_source: str = Field(default="none", description="How the agent YAML was selected")
|
||||
|
||||
@property
|
||||
def chunk_size_bytes(self) -> int:
|
||||
@@ -146,7 +522,7 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
settings = load_settings()
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
|
||||
@@ -20,11 +20,11 @@ except ImportError:
|
||||
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
|
||||
|
||||
from app.config import settings
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
|
||||
from core.session import Session
|
||||
from processors.tracks import Resampled16kTrack
|
||||
from core.events import get_event_bus, reset_event_bus
|
||||
from models.ws_v1 import ev
|
||||
|
||||
# Check interval for heartbeat/timeout (seconds)
|
||||
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
||||
@@ -54,9 +54,7 @@ async def heartbeat_and_timeout_task(
|
||||
break
|
||||
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
||||
try:
|
||||
await transport.send_event({
|
||||
**ev("heartbeat"),
|
||||
})
|
||||
await session.send_heartbeat()
|
||||
last_heartbeat_at[0] = now
|
||||
except Exception as e:
|
||||
logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
|
||||
@@ -78,6 +76,7 @@ app.add_middleware(
|
||||
|
||||
# Active sessions storage
|
||||
active_sessions: Dict[str, Session] = {}
|
||||
backend_gateway = build_backend_adapter_from_settings()
|
||||
|
||||
# Configure logging
|
||||
logger.remove()
|
||||
@@ -167,7 +166,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
# Create transport and session
|
||||
transport = SocketTransport(websocket)
|
||||
session = Session(session_id, transport)
|
||||
session = Session(session_id, transport, backend_gateway=backend_gateway)
|
||||
active_sessions[session_id] = session
|
||||
|
||||
logger.info(f"WebSocket connection established: {session_id}")
|
||||
@@ -246,7 +245,7 @@ async def webrtc_endpoint(websocket: WebSocket):
|
||||
|
||||
# Create transport and session
|
||||
transport = WebRtcTransport(websocket, pc)
|
||||
session = Session(session_id, transport)
|
||||
session = Session(session_id, transport, backend_gateway=backend_gateway)
|
||||
active_sessions[session_id] = session
|
||||
|
||||
logger.info(f"WebRTC connection established: {session_id}")
|
||||
@@ -360,6 +359,12 @@ async def startup_event():
|
||||
logger.info(f"Server: {settings.host}:{settings.port}")
|
||||
logger.info(f"Sample rate: {settings.sample_rate} Hz")
|
||||
logger.info(f"VAD model: {settings.vad_model_path}")
|
||||
if settings.agent_config_path:
|
||||
logger.info(
|
||||
f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}"
|
||||
)
|
||||
else:
|
||||
logger.info("Agent config: none (using .env/default agent values)")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
|
||||
50
engine/config/agents/example.yaml
Normal file
50
engine/config/agents/example.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# Agent behavior configuration (safe to edit per profile)
|
||||
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
|
||||
# Infra/server/network settings should stay in .env.
|
||||
|
||||
agent:
|
||||
vad:
|
||||
type: silero
|
||||
model_path: data/vad/silero_vad.onnx
|
||||
threshold: 0.5
|
||||
min_speech_duration_ms: 100
|
||||
eou_threshold_ms: 800
|
||||
|
||||
llm:
|
||||
# provider: openai | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
model: deepseek-v3
|
||||
temperature: 0.7
|
||||
# Required (there is no fallback value). To keep the key out of this file, reference an environment variable explicitly, e.g. ${LLM_API_KEY}.
|
||||
api_key: your_llm_api_key
|
||||
# Optional for OpenAI-compatible endpoints:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: edge | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
model: FunAudioLLM/CosyVoice2-0.5B
|
||||
voice: anna
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
# provider: buffered | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
api_key: your_asr_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
model: FunAudioLLM/SenseVoiceSmall
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
pre_speech_ms: 240
|
||||
final_tail_ms: 120
|
||||
|
||||
duplex:
|
||||
enabled: true
|
||||
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||
|
||||
barge_in:
|
||||
min_duration_ms: 200
|
||||
silence_tolerance_ms: 60
|
||||
73
engine/config/agents/tools.yaml
Normal file
73
engine/config/agents/tools.yaml
Normal file
@@ -0,0 +1,73 @@
|
||||
# Agent behavior configuration with tool declarations.
|
||||
# This profile is an example only.
|
||||
|
||||
agent:
|
||||
vad:
|
||||
type: silero
|
||||
model_path: data/vad/silero_vad.onnx
|
||||
threshold: 0.5
|
||||
min_speech_duration_ms: 100
|
||||
eou_threshold_ms: 800
|
||||
|
||||
llm:
|
||||
# provider: openai | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
model: deepseek-v3
|
||||
temperature: 0.7
|
||||
api_key: your_llm_api_key
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: edge | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
model: FunAudioLLM/CosyVoice2-0.5B
|
||||
voice: anna
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
# provider: buffered | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
api_key: your_asr_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
model: FunAudioLLM/SenseVoiceSmall
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
pre_speech_ms: 240
|
||||
final_tail_ms: 120
|
||||
|
||||
duplex:
|
||||
enabled: true
|
||||
system_prompt: You are a helpful voice assistant with tool-calling support.
|
||||
|
||||
barge_in:
|
||||
min_duration_ms: 200
|
||||
silence_tolerance_ms: 60
|
||||
|
||||
# Tool declarations consumed by the engine at startup.
|
||||
# - String form enables built-in/default tool schema when available.
|
||||
# - Object form provides OpenAI function schema + executor hint.
|
||||
tools:
|
||||
- current_time
|
||||
- calculator
|
||||
- name: weather
|
||||
description: Get weather by city name.
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
city:
|
||||
type: string
|
||||
description: City name, for example "San Francisco".
|
||||
required: [city]
|
||||
executor: server
|
||||
- name: open_map
|
||||
description: Open map app on the client device.
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
required: [query]
|
||||
executor: client
|
||||
@@ -14,7 +14,8 @@ event-driven design.
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -59,6 +60,12 @@ class DuplexPipeline:
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
||||
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
||||
TRACK_AUDIO_IN = "audio_in"
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
_PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||
_ASR_DELTA_THROTTLE_MS = 300
|
||||
_LLM_DELTA_THROTTLE_MS = 80
|
||||
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
||||
"current_time": {
|
||||
"name": "current_time",
|
||||
@@ -79,7 +86,16 @@ class DuplexPipeline:
|
||||
tts_service: Optional[BaseTTSService] = None,
|
||||
asr_service: Optional[BaseASRService] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
greeting: Optional[str] = None
|
||||
greeting: Optional[str] = None,
|
||||
knowledge_searcher: Optional[
|
||||
Callable[..., Awaitable[List[Dict[str, Any]]]]
|
||||
] = None,
|
||||
tool_resource_resolver: Optional[
|
||||
Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
|
||||
] = None,
|
||||
server_tool_executor: Optional[
|
||||
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
"""
|
||||
Initialize duplex pipeline.
|
||||
@@ -96,6 +112,9 @@ class DuplexPipeline:
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self.event_bus = get_event_bus()
|
||||
self.track_audio_in = self.TRACK_AUDIO_IN
|
||||
self.track_audio_out = self.TRACK_AUDIO_OUT
|
||||
self.track_control = self.TRACK_CONTROL
|
||||
|
||||
# Initialize VAD
|
||||
self.vad_model = SileroVAD(
|
||||
@@ -117,9 +136,14 @@ class DuplexPipeline:
|
||||
self.llm_service = llm_service
|
||||
self.tts_service = tts_service
|
||||
self.asr_service = asr_service # Will be initialized in start()
|
||||
self._knowledge_searcher = knowledge_searcher
|
||||
self._tool_resource_resolver = tool_resource_resolver
|
||||
self._server_tool_executor = server_tool_executor
|
||||
|
||||
# Track last sent transcript to avoid duplicates
|
||||
self._last_sent_transcript = ""
|
||||
self._pending_transcript_delta: str = ""
|
||||
self._last_transcript_delta_emit_ms: float = 0.0
|
||||
|
||||
# Conversation manager
|
||||
self.conversation = ConversationManager(
|
||||
@@ -153,6 +177,7 @@ class DuplexPipeline:
|
||||
self._outbound_seq = 0
|
||||
self._outbound_task: Optional[asyncio.Task] = None
|
||||
self._drop_outbound_audio = False
|
||||
self._audio_out_frame_buffer: bytes = b""
|
||||
|
||||
# Interruption handling
|
||||
self._interrupt_event = asyncio.Event()
|
||||
@@ -181,14 +206,48 @@ class DuplexPipeline:
|
||||
self._runtime_barge_in_min_duration_ms: Optional[int] = None
|
||||
self._runtime_knowledge: Dict[str, Any] = {}
|
||||
self._runtime_knowledge_base_id: Optional[str] = None
|
||||
self._runtime_tools: List[Any] = []
|
||||
raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
|
||||
self._runtime_tools: List[Any] = list(raw_default_tools)
|
||||
self._runtime_tool_executor: Dict[str, str] = {}
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
self._pending_client_tool_call_ids: set[str] = set()
|
||||
self._next_seq: Optional[Callable[[], int]] = None
|
||||
self._local_seq: int = 0
|
||||
|
||||
# Cross-service correlation IDs
|
||||
self._turn_count: int = 0
|
||||
self._response_count: int = 0
|
||||
self._tts_count: int = 0
|
||||
self._utterance_count: int = 0
|
||||
self._current_turn_id: Optional[str] = None
|
||||
self._current_utterance_id: Optional[str] = None
|
||||
self._current_response_id: Optional[str] = None
|
||||
self._current_tts_id: Optional[str] = None
|
||||
self._pending_llm_delta: str = ""
|
||||
self._last_llm_delta_emit_ms: float = 0.0
|
||||
|
||||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||||
|
||||
if self._server_tool_executor is None:
|
||||
if self._tool_resource_resolver:
|
||||
async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return await execute_server_tool(
|
||||
call,
|
||||
tool_resource_fetcher=self._tool_resource_resolver,
|
||||
)
|
||||
|
||||
self._server_tool_executor = _executor
|
||||
else:
|
||||
self._server_tool_executor = execute_server_tool
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
    """Install the callable that supplies sequence numbers for envelope events.

    Args:
        provider: Zero-argument callable returning the next sequence number.
            It is stored on the pipeline and consulted by ``_next_event_seq``.
    """
    self._next_seq = provider
|
||||
|
||||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Apply runtime overrides from WS session.start metadata.
|
||||
@@ -276,6 +335,136 @@ class DuplexPipeline:
|
||||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||||
|
||||
def resolved_runtime_config(self) -> Dict[str, Any]:
    """Build a snapshot of the effective runtime configuration.

    Every value comes from the per-session runtime override when that
    override is truthy, and from the process-wide ``settings`` otherwise.
    No API key is read here, so the snapshot carries no credentials.

    Returns:
        A dict with ``output``, ``services`` (llm / asr / tts), ``tools``
        and ``tracks`` sections.
    """
    llm_cfg = self._runtime_llm
    asr_cfg = self._runtime_asr
    tts_cfg = self._runtime_tts

    llm_provider = str(llm_cfg.get("provider") or settings.llm_provider).lower()
    # Base URL precedence: session override, global setting, provider default.
    llm_base_url = (
        llm_cfg.get("baseUrl")
        or settings.llm_api_url
        or self._default_llm_base_url(llm_provider)
    )
    tts_provider = str(tts_cfg.get("provider") or settings.tts_provider).lower()
    asr_provider = str(asr_cfg.get("provider") or settings.asr_provider).lower()

    # An explicit output mode wins; otherwise derive it from the TTS switch.
    output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
    if not output_mode:
        output_mode = "audio" if self._tts_output_enabled() else "text"

    llm_section = {
        "provider": llm_provider,
        "model": str(llm_cfg.get("model") or settings.llm_model),
        "baseUrl": llm_base_url,
    }
    asr_section = {
        "provider": asr_provider,
        "model": str(asr_cfg.get("model") or settings.asr_model or ""),
        "interimIntervalMs": int(asr_cfg.get("interimIntervalMs") or settings.asr_interim_interval_ms),
        "minAudioMs": int(asr_cfg.get("minAudioMs") or settings.asr_min_audio_ms),
    }
    tts_section = {
        "enabled": self._tts_output_enabled(),
        "provider": tts_provider,
        "model": str(tts_cfg.get("model") or settings.tts_model or ""),
        "voice": str(tts_cfg.get("voice") or settings.tts_voice),
        "speed": float(tts_cfg.get("speed") or settings.tts_speed),
    }

    return {
        "output": {"mode": output_mode},
        "services": {
            "llm": llm_section,
            "asr": asr_section,
            "tts": tts_section,
        },
        "tools": {
            "allowlist": self._resolved_tool_allowlist(),
        },
        "tracks": {
            "audio_in": self.track_audio_in,
            "audio_out": self.track_audio_out,
            "control": self.track_control,
        },
    }
|
||||
|
||||
def _next_event_seq(self) -> int:
|
||||
if self._next_seq:
|
||||
return self._next_seq()
|
||||
self._local_seq += 1
|
||||
return self._local_seq
|
||||
|
||||
def _event_source(self, event_type: str) -> str:
|
||||
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
|
||||
return "asr"
|
||||
if event_type.startswith("assistant.response."):
|
||||
return "llm"
|
||||
if event_type.startswith("assistant.tool_"):
|
||||
return "tool"
|
||||
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
|
||||
return "tts"
|
||||
return "system"
|
||||
|
||||
def _new_id(self, prefix: str, counter: int) -> str:
|
||||
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _start_turn(self) -> str:
|
||||
self._turn_count += 1
|
||||
self._current_turn_id = self._new_id("turn", self._turn_count)
|
||||
self._current_utterance_id = None
|
||||
self._current_response_id = None
|
||||
self._current_tts_id = None
|
||||
return self._current_turn_id
|
||||
|
||||
def _start_response(self) -> str:
|
||||
self._response_count += 1
|
||||
self._current_response_id = self._new_id("resp", self._response_count)
|
||||
self._current_tts_id = None
|
||||
return self._current_response_id
|
||||
|
||||
def _start_tts(self) -> str:
|
||||
self._tts_count += 1
|
||||
self._current_tts_id = self._new_id("tts", self._tts_count)
|
||||
return self._current_tts_id
|
||||
|
||||
def _finalize_utterance(self) -> str:
|
||||
if self._current_utterance_id:
|
||||
return self._current_utterance_id
|
||||
self._utterance_count += 1
|
||||
self._current_utterance_id = self._new_id("utt", self._utterance_count)
|
||||
if not self._current_turn_id:
|
||||
self._start_turn()
|
||||
return self._current_utterance_id
|
||||
|
||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = str(event.get("type") or "")
|
||||
source = str(event.get("source") or self._event_source(event_type))
|
||||
track_id = event.get("trackId")
|
||||
if not track_id:
|
||||
if source == "asr":
|
||||
track_id = self.track_audio_in
|
||||
elif source in {"llm", "tts", "tool"}:
|
||||
track_id = self.track_audio_out
|
||||
else:
|
||||
track_id = self.track_control
|
||||
|
||||
data = event.get("data")
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
if self._current_turn_id:
|
||||
data.setdefault("turn_id", self._current_turn_id)
|
||||
if self._current_utterance_id:
|
||||
data.setdefault("utterance_id", self._current_utterance_id)
|
||||
if self._current_response_id:
|
||||
data.setdefault("response_id", self._current_response_id)
|
||||
if self._current_tts_id:
|
||||
data.setdefault("tts_id", self._current_tts_id)
|
||||
|
||||
for k, v in event.items():
|
||||
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
|
||||
continue
|
||||
data.setdefault(k, v)
|
||||
|
||||
event["sessionId"] = self.session_id
|
||||
event["seq"] = self._next_event_seq()
|
||||
event["source"] = source
|
||||
event["trackId"] = track_id
|
||||
event["data"] = data
|
||||
return event
|
||||
|
||||
@staticmethod
|
||||
def _coerce_bool(value: Any) -> Optional[bool]:
|
||||
if isinstance(value, bool):
|
||||
@@ -295,6 +484,18 @@ class DuplexPipeline:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
@staticmethod
|
||||
def _is_llm_provider_supported(provider: Any) -> bool:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
@staticmethod
|
||||
def _default_llm_base_url(provider: Any) -> Optional[str]:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
if normalized == "siliconflow":
|
||||
return "https://api.siliconflow.cn/v1"
|
||||
return None
|
||||
|
||||
def _tts_output_enabled(self) -> bool:
|
||||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||||
if enabled is not None:
|
||||
@@ -370,20 +571,25 @@ class DuplexPipeline:
|
||||
try:
|
||||
# Connect LLM service
|
||||
if not self.llm_service:
|
||||
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
|
||||
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
|
||||
llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
|
||||
llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
|
||||
llm_base_url = (
|
||||
self._runtime_llm.get("baseUrl")
|
||||
or settings.llm_api_url
|
||||
or self._default_llm_base_url(llm_provider)
|
||||
)
|
||||
llm_model = self._runtime_llm.get("model") or settings.llm_model
|
||||
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
|
||||
|
||||
if llm_provider == "openai" and llm_api_key:
|
||||
if self._is_llm_provider_supported(llm_provider) and llm_api_key:
|
||||
self.llm_service = OpenAILLMService(
|
||||
api_key=llm_api_key,
|
||||
base_url=llm_base_url,
|
||||
model=llm_model,
|
||||
knowledge_config=self._resolved_knowledge_config(),
|
||||
knowledge_searcher=self._knowledge_searcher,
|
||||
)
|
||||
else:
|
||||
logger.warning("No OpenAI API key - using mock LLM")
|
||||
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
|
||||
self.llm_service = MockLLMService()
|
||||
|
||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||
@@ -399,20 +605,22 @@ class DuplexPipeline:
|
||||
if tts_output_enabled:
|
||||
if not self.tts_service:
|
||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
|
||||
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||
tts_model = self._runtime_tts.get("model") or settings.tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
|
||||
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
||||
self.tts_service = OpenAICompatibleTTSService(
|
||||
api_key=tts_api_key,
|
||||
api_url=tts_api_url,
|
||||
voice=tts_voice,
|
||||
model=tts_model,
|
||||
model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
|
||||
logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=tts_voice,
|
||||
@@ -435,21 +643,23 @@ class DuplexPipeline:
|
||||
# Connect ASR service
|
||||
if not self.asr_service:
|
||||
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
|
||||
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
|
||||
asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
|
||||
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
|
||||
asr_model = self._runtime_asr.get("model") or settings.asr_model
|
||||
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||
|
||||
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
|
||||
self.asr_service = OpenAICompatibleASRService(
|
||||
api_key=asr_api_key,
|
||||
model=asr_model,
|
||||
api_url=asr_api_url,
|
||||
model=asr_model or "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate=settings.sample_rate,
|
||||
interim_interval_ms=asr_interim_interval,
|
||||
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||
on_transcript=self._on_transcript_callback
|
||||
)
|
||||
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
|
||||
logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
|
||||
else:
|
||||
self.asr_service = BufferedASRService(
|
||||
sample_rate=settings.sample_rate
|
||||
@@ -472,11 +682,13 @@ class DuplexPipeline:
|
||||
greeting_to_speak = generated_greeting
|
||||
self.conversation.greeting = generated_greeting
|
||||
if greeting_to_speak:
|
||||
self._start_turn()
|
||||
self._start_response()
|
||||
await self._send_event(
|
||||
ev(
|
||||
"assistant.response.final",
|
||||
text=greeting_to_speak,
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
),
|
||||
priority=20,
|
||||
)
|
||||
@@ -494,10 +706,58 @@ class DuplexPipeline:
|
||||
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
|
||||
|
||||
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
|
||||
await self._enqueue_outbound("event", event, priority)
|
||||
await self._enqueue_outbound("event", self._envelope_event(event), priority)
|
||||
|
||||
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
|
||||
await self._enqueue_outbound("audio", pcm_bytes, priority)
|
||||
if not pcm_bytes:
|
||||
return
|
||||
self._audio_out_frame_buffer += pcm_bytes
|
||||
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
|
||||
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
|
||||
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
|
||||
await self._enqueue_outbound("audio", frame, priority)
|
||||
|
||||
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
|
||||
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
|
||||
if not self._audio_out_frame_buffer:
|
||||
return
|
||||
tail = self._audio_out_frame_buffer
|
||||
self._audio_out_frame_buffer = b""
|
||||
if len(tail) < self._PCM_FRAME_BYTES:
|
||||
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
|
||||
await self._enqueue_outbound("audio", tail, priority)
|
||||
|
||||
async def _emit_transcript_delta(self, text: str) -> None:
    """Send a ``transcript.delta`` event carrying ``text`` on the audio-in track.

    The event is sent with priority 30.
    """
    payload = {**ev("transcript.delta", trackId=self.track_audio_in, text=text)}
    await self._send_event(payload, priority=30)
|
||||
|
||||
async def _emit_llm_delta(self, text: str) -> None:
    """Send an ``assistant.response.delta`` event with ``text`` on the audio-out track.

    The event is sent with priority 20.
    """
    payload = {**ev("assistant.response.delta", trackId=self.track_audio_out, text=text)}
    await self._send_event(payload, priority=20)
|
||||
|
||||
async def _flush_pending_llm_delta(self) -> None:
|
||||
if not self._pending_llm_delta:
|
||||
return
|
||||
chunk = self._pending_llm_delta
|
||||
self._pending_llm_delta = ""
|
||||
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
|
||||
await self._emit_llm_delta(chunk)
|
||||
|
||||
async def _outbound_loop(self) -> None:
|
||||
"""Single sender loop that enforces priority for interrupt events."""
|
||||
@@ -546,13 +806,13 @@ class DuplexPipeline:
|
||||
|
||||
# Emit VAD event
|
||||
await self.event_bus.publish(event_type, {
|
||||
"trackId": self.session_id,
|
||||
"trackId": self.track_audio_in,
|
||||
"probability": probability
|
||||
})
|
||||
await self._send_event(
|
||||
ev(
|
||||
"input.speech_started" if event_type == "speaking" else "input.speech_stopped",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_in,
|
||||
probability=probability,
|
||||
),
|
||||
priority=30,
|
||||
@@ -661,6 +921,9 @@ class DuplexPipeline:
|
||||
# Cancel any current speaking
|
||||
await self._stop_current_speech()
|
||||
|
||||
self._start_turn()
|
||||
self._finalize_utterance()
|
||||
|
||||
# Start new turn
|
||||
await self.conversation.end_user_turn(text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||||
@@ -683,24 +946,45 @@ class DuplexPipeline:
|
||||
if text == self._last_sent_transcript and not is_final:
|
||||
return
|
||||
|
||||
now_ms = time.monotonic() * 1000.0
|
||||
self._last_sent_transcript = text
|
||||
|
||||
# Send transcript event to client
|
||||
await self._send_event({
|
||||
if is_final:
|
||||
self._pending_transcript_delta = ""
|
||||
self._last_transcript_delta_emit_ms = 0.0
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"transcript.final" if is_final else "transcript.delta",
|
||||
trackId=self.session_id,
|
||||
"transcript.final",
|
||||
trackId=self.track_audio_in,
|
||||
text=text,
|
||||
)
|
||||
}, priority=30)
|
||||
},
|
||||
priority=30,
|
||||
)
|
||||
logger.debug(f"Sent transcript (final): {text[:50]}...")
|
||||
return
|
||||
|
||||
self._pending_transcript_delta = text
|
||||
should_emit = (
|
||||
self._last_transcript_delta_emit_ms <= 0.0
|
||||
or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
|
||||
)
|
||||
if should_emit and self._pending_transcript_delta:
|
||||
delta = self._pending_transcript_delta
|
||||
self._pending_transcript_delta = ""
|
||||
self._last_transcript_delta_emit_ms = now_ms
|
||||
await self._emit_transcript_delta(delta)
|
||||
|
||||
if not is_final:
|
||||
logger.info(f"[ASR] ASR interim: {text[:100]}")
|
||||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||
logger.debug(f"Sent transcript (interim): {text[:50]}...")
|
||||
|
||||
async def _on_speech_start(self) -> None:
|
||||
"""Handle user starting to speak."""
|
||||
if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
|
||||
self._start_turn()
|
||||
self._finalize_utterance()
|
||||
await self.conversation.start_user_turn()
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
@@ -779,6 +1063,7 @@ class DuplexPipeline:
|
||||
return
|
||||
|
||||
logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
|
||||
self._finalize_utterance()
|
||||
|
||||
# For ASR backends that already emitted final via callback,
|
||||
# avoid duplicating transcript.final on EOU.
|
||||
@@ -786,7 +1071,7 @@ class DuplexPipeline:
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"transcript.final",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_in,
|
||||
text=user_text,
|
||||
)
|
||||
}, priority=25)
|
||||
@@ -794,6 +1079,8 @@ class DuplexPipeline:
|
||||
# Clear buffers
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
self._pending_transcript_delta = ""
|
||||
self._last_transcript_delta_emit_ms = 0.0
|
||||
self._asr_capture_active = False
|
||||
self._pending_speech_audio = b""
|
||||
|
||||
@@ -881,6 +1168,23 @@ class DuplexPipeline:
|
||||
result[name] = executor
|
||||
return result
|
||||
|
||||
def _resolved_tool_allowlist(self) -> List[str]:
|
||||
names: set[str] = set()
|
||||
for item in self._runtime_tools:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
names.add(name)
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.add(str(fn.get("name")).strip())
|
||||
elif item.get("name"):
|
||||
names.add(str(item.get("name")).strip())
|
||||
return sorted([name for name in names if name])
|
||||
|
||||
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
|
||||
fn = tool_call.get("function")
|
||||
if isinstance(fn, dict):
|
||||
@@ -894,6 +1198,44 @@ class DuplexPipeline:
|
||||
# Default to server execution unless explicitly marked as client.
|
||||
return "server"
|
||||
|
||||
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
fn = tool_call.get("function")
|
||||
if not isinstance(fn, dict):
|
||||
return {}
|
||||
raw = fn.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return parsed if isinstance(parsed, dict) else {"raw": raw}
|
||||
except Exception:
|
||||
return {"raw": raw}
|
||||
return {}
|
||||
|
||||
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
||||
status_code = int(status.get("code") or 0) if status else 0
|
||||
status_message = str(status.get("message") or "") if status else ""
|
||||
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
|
||||
tool_name = str(result.get("name") or "unknown_tool")
|
||||
ok = bool(200 <= status_code < 300)
|
||||
retryable = status_code >= 500 or status_code in {429, 408}
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
if not ok:
|
||||
error = {
|
||||
"code": status_code or 500,
|
||||
"message": status_message or "tool_execution_failed",
|
||||
"retryable": retryable,
|
||||
}
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"ok": ok,
|
||||
"error": error,
|
||||
"status": {"code": status_code, "message": status_message},
|
||||
}
|
||||
|
||||
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
|
||||
tool_name = str(result.get("name") or "unknown_tool")
|
||||
call_id = str(result.get("tool_call_id") or result.get("id") or "")
|
||||
@@ -904,12 +1246,17 @@ class DuplexPipeline:
|
||||
f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
|
||||
f"status={status_code} {status_message}".strip()
|
||||
)
|
||||
normalized = self._normalize_tool_result(result)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.tool_result",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
source=source,
|
||||
tool_call_id=normalized["tool_call_id"],
|
||||
tool_name=normalized["tool_name"],
|
||||
ok=normalized["ok"],
|
||||
error=normalized["error"],
|
||||
result=result,
|
||||
)
|
||||
},
|
||||
@@ -927,6 +1274,9 @@ class DuplexPipeline:
|
||||
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
|
||||
if not call_id:
|
||||
continue
|
||||
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
|
||||
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
|
||||
continue
|
||||
if call_id in self._completed_tool_call_ids:
|
||||
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
|
||||
continue
|
||||
@@ -972,6 +1322,7 @@ class DuplexPipeline:
|
||||
}
|
||||
finally:
|
||||
self._pending_tool_waiters.pop(call_id, None)
|
||||
self._pending_client_tool_call_ids.discard(call_id)
|
||||
|
||||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
||||
if isinstance(item, LLMStreamEvent):
|
||||
@@ -998,6 +1349,11 @@ class DuplexPipeline:
|
||||
user_text: User's transcribed text
|
||||
"""
|
||||
try:
|
||||
if not self._current_turn_id:
|
||||
self._start_turn()
|
||||
if not self._current_utterance_id:
|
||||
self._finalize_utterance()
|
||||
self._start_response()
|
||||
# Start latency tracking
|
||||
self._turn_start_time = time.time()
|
||||
self._first_audio_sent = False
|
||||
@@ -1012,6 +1368,8 @@ class DuplexPipeline:
|
||||
self._drop_outbound_audio = False
|
||||
|
||||
first_audio_sent = False
|
||||
self._pending_llm_delta = ""
|
||||
self._last_llm_delta_emit_ms = 0.0
|
||||
for _ in range(max_rounds):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
@@ -1028,6 +1386,7 @@ class DuplexPipeline:
|
||||
|
||||
event = self._normalize_stream_event(raw_event)
|
||||
if event.type == "tool_call":
|
||||
await self._flush_pending_llm_delta()
|
||||
tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
|
||||
if not tool_call:
|
||||
continue
|
||||
@@ -1045,11 +1404,19 @@ class DuplexPipeline:
|
||||
f"executor={executor} args={args_preview}"
|
||||
)
|
||||
tool_calls.append(enriched_tool_call)
|
||||
tool_arguments = self._tool_arguments(enriched_tool_call)
|
||||
if executor == "client" and call_id:
|
||||
self._pending_client_tool_call_ids.add(call_id)
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.tool_call",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
tool_call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
executor=executor,
|
||||
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
|
||||
tool_call=enriched_tool_call,
|
||||
)
|
||||
},
|
||||
@@ -1071,19 +1438,13 @@ class DuplexPipeline:
|
||||
round_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.response.delta",
|
||||
trackId=self.session_id,
|
||||
text=text_chunk,
|
||||
)
|
||||
},
|
||||
# Keep delta/final on the same event priority so FIFO seq
|
||||
# preserves stream order (avoid late-delta after final).
|
||||
priority=20,
|
||||
)
|
||||
self._pending_llm_delta += text_chunk
|
||||
now_ms = time.monotonic() * 1000.0
|
||||
if (
|
||||
self._last_llm_delta_emit_ms <= 0.0
|
||||
or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
|
||||
):
|
||||
await self._flush_pending_llm_delta()
|
||||
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
@@ -1112,11 +1473,12 @@ class DuplexPipeline:
|
||||
|
||||
if self._tts_output_enabled() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
@@ -1130,6 +1492,7 @@ class DuplexPipeline:
|
||||
)
|
||||
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
await self._flush_pending_llm_delta()
|
||||
if (
|
||||
self._tts_output_enabled()
|
||||
and remaining_text
|
||||
@@ -1137,11 +1500,12 @@ class DuplexPipeline:
|
||||
and not self._interrupt_event.is_set()
|
||||
):
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
@@ -1172,7 +1536,7 @@ class DuplexPipeline:
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
execute_server_tool(call),
|
||||
self._server_tool_executor(call),
|
||||
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -1204,11 +1568,12 @@ class DuplexPipeline:
|
||||
]
|
||||
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self._flush_pending_llm_delta()
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"assistant.response.final",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
text=full_response,
|
||||
)
|
||||
},
|
||||
@@ -1217,10 +1582,11 @@ class DuplexPipeline:
|
||||
|
||||
# Send track end
|
||||
if first_audio_sent:
|
||||
await self._flush_audio_out_frames(priority=50)
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"output.audio.end",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
}, priority=10)
|
||||
|
||||
@@ -1241,6 +1607,8 @@ class DuplexPipeline:
|
||||
self._barge_in_speech_start_time = None
|
||||
self._barge_in_speech_frames = 0
|
||||
self._barge_in_silence_frames = 0
|
||||
self._current_response_id = None
|
||||
self._current_tts_id = None
|
||||
|
||||
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
|
||||
"""
|
||||
@@ -1277,7 +1645,7 @@ class DuplexPipeline:
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"metrics.ttfb",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
}, priority=25)
|
||||
@@ -1354,10 +1722,11 @@ class DuplexPipeline:
|
||||
first_audio_sent = False
|
||||
|
||||
# Send track start event
|
||||
self._start_tts()
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
}, priority=10)
|
||||
|
||||
@@ -1379,7 +1748,7 @@ class DuplexPipeline:
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"metrics.ttfb",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
}, priority=25)
|
||||
@@ -1391,10 +1760,11 @@ class DuplexPipeline:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Send track end event
|
||||
await self._flush_audio_out_frames(priority=50)
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"output.audio.end",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
}, priority=10)
|
||||
|
||||
@@ -1422,13 +1792,14 @@ class DuplexPipeline:
|
||||
self._interrupt_event.set()
|
||||
self._is_bot_speaking = False
|
||||
self._drop_outbound_audio = True
|
||||
self._audio_out_frame_buffer = b""
|
||||
|
||||
# Send interrupt event to client IMMEDIATELY
|
||||
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
||||
await self._send_event({
|
||||
**ev(
|
||||
"response.interrupted",
|
||||
trackId=self.session_id,
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
}, priority=0)
|
||||
|
||||
@@ -1455,6 +1826,7 @@ class DuplexPipeline:
|
||||
async def _stop_current_speech(self) -> None:
|
||||
"""Stop any current speech task."""
|
||||
self._drop_outbound_audio = True
|
||||
self._audio_out_frame_buffer = b""
|
||||
if self._current_turn_task and not self._current_turn_task.done():
|
||||
self._interrupt_event.set()
|
||||
self._current_turn_task.cancel()
|
||||
|
||||
244
engine/core/history_bridge.py
Normal file
244
engine/core/history_bridge.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Async history bridge for non-blocking transcript persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class _HistoryTranscriptJob:
|
||||
call_id: str
|
||||
turn_index: int
|
||||
speaker: str
|
||||
content: str
|
||||
start_ms: int
|
||||
end_ms: int
|
||||
duration_ms: int
|
||||
|
||||
|
||||
class SessionHistoryBridge:
    """Session-scoped buffered history writer with background retries.

    Transcript turns are placed on a bounded in-memory queue by
    :meth:`enqueue_turn` (which never blocks the caller) and are written to
    the injected ``history_writer`` by a single background task. The writer
    is expected to expose the async methods ``create_call_record``,
    ``add_transcript`` and ``finalize_call_record``.
    """

    # Unique marker object put on the queue to ask the worker loop to exit.
    _STOP_SENTINEL = object()

    def __init__(
        self,
        *,
        history_writer: Any,
        enabled: bool,
        queue_max_size: int,
        retry_max_attempts: int,
        retry_backoff_sec: float,
        finalize_drain_timeout_sec: float,
    ):
        """Store configuration; no I/O happens until :meth:`start_call`.

        Args:
            history_writer: Object providing the async history methods. When
                ``None`` the bridge is disabled regardless of ``enabled``.
            enabled: Master switch for history persistence.
            queue_max_size: Queue capacity; clamped to at least 1.
            retry_max_attempts: Extra attempts after the first write; clamped
                to at least 0.
            retry_backoff_sec: Base delay between attempts (doubled on each
                retry); clamped to at least 0.
            finalize_drain_timeout_sec: How long :meth:`finalize` waits for
                the queue to drain; ``0`` skips the wait entirely.
        """
        self._history_writer = history_writer
        self._enabled = bool(enabled and history_writer is not None)
        self._queue_max_size = max(1, int(queue_max_size))
        self._retry_max_attempts = max(0, int(retry_max_attempts))
        self._retry_backoff_sec = max(0.0, float(retry_backoff_sec))
        self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec))

        self._call_id: Optional[str] = None
        self._turn_index: int = 0
        self._started_mono: Optional[float] = None
        self._finalized: bool = False
        self._worker_task: Optional[asyncio.Task] = None
        self._finalize_lock = asyncio.Lock()
        self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size)

    @property
    def enabled(self) -> bool:
        """Whether history persistence is active for this session."""
        return self._enabled

    @property
    def call_id(self) -> Optional[str]:
        """Backend call ID, or ``None`` until :meth:`start_call` succeeds."""
        return self._call_id

    async def start_call(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str,
    ) -> Optional[str]:
        """Create remote call record and start background worker.

        Returns the existing call ID when a call was already started, and
        ``None`` when the bridge is disabled or the writer returned a falsy
        call ID.
        """
        if not self._enabled or self._call_id:
            return self._call_id

        call_id = await self._history_writer.create_call_record(
            user_id=user_id,
            assistant_id=assistant_id,
            source=source,
        )
        if not call_id:
            return None

        self._call_id = str(call_id)
        self._turn_index = 0
        self._finalized = False
        self._started_mono = time.monotonic()
        self._ensure_worker()
        return self._call_id

    def elapsed_ms(self) -> int:
        """Milliseconds since :meth:`start_call` succeeded (0 before that)."""
        if self._started_mono is None:
            return 0
        return max(0, int((time.monotonic() - self._started_mono) * 1000))

    def enqueue_turn(self, *, role: str, text: str) -> bool:
        """Queue one transcript write without blocking the caller.

        Returns ``True`` when the turn was queued. Returns ``False`` when the
        bridge is disabled, no call is active, the call is finalized, the
        text is empty after stripping, or the queue is full (the turn is
        dropped and a warning is logged).
        """
        if not self._enabled or not self._call_id or self._finalized:
            return False

        content = str(text or "").strip()
        if not content:
            return False

        speaker = "human" if str(role or "").strip().lower() == "user" else "ai"
        end_ms = self.elapsed_ms()
        # No real audio timing is available here, so the duration is estimated
        # from the text length (80 ms per character, clamped to 0.3s..12s).
        estimated_duration_ms = max(300, min(12000, len(content) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        job = _HistoryTranscriptJob(
            call_id=self._call_id,
            turn_index=self._turn_index,
            speaker=speaker,
            content=content,
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        # The index advances even if the job is dropped below.
        self._turn_index += 1
        self._ensure_worker()

        try:
            self._queue.put_nowait(job)
            return True
        except asyncio.QueueFull:
            logger.warning(
                "History queue full; dropping transcript call_id={} turn={}",
                self._call_id,
                job.turn_index,
            )
            return False

    async def finalize(self, *, status: str) -> bool:
        """Finalize history record once; waits briefly for queue drain.

        Returns ``False`` when the bridge is disabled or no call was started,
        ``True`` when the call was already finalized, and otherwise the
        result of the writer's ``finalize_call_record``. The worker is
        stopped afterwards regardless of that result.
        """
        if not self._enabled or not self._call_id:
            return False

        async with self._finalize_lock:
            if self._finalized:
                return True

            await self._drain_queue()
            ok = await self._history_writer.finalize_call_record(
                call_id=self._call_id,
                status=status,
                duration_seconds=self.duration_seconds(),
            )
            if ok:
                self._finalized = True
            await self._stop_worker()
            return ok

    async def shutdown(self) -> None:
        """Stop worker task and release queue resources."""
        await self._stop_worker()

    def duration_seconds(self) -> int:
        """Whole seconds since :meth:`start_call` succeeded (0 before that)."""
        if self._started_mono is None:
            return 0
        return max(0, int(time.monotonic() - self._started_mono))

    def _ensure_worker(self) -> None:
        """Start the background worker unless one is already running."""
        if self._worker_task and not self._worker_task.done():
            return
        self._worker_task = asyncio.create_task(self._worker_loop())

    async def _drain_queue(self) -> None:
        """Wait up to the configured timeout for all queued jobs to finish."""
        if self._finalize_drain_timeout_sec <= 0:
            return
        try:
            await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec)
        except asyncio.TimeoutError:
            logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec)

    async def _stop_worker(self) -> None:
        """Ask the worker to exit via the sentinel, cancelling it if needed.

        The sentinel is queued behind any pending jobs; if it cannot be queued
        within 0.5s, or the worker does not exit within 1.5s, the task is
        cancelled. Cancellation of the worker is never propagated to the
        caller.
        """
        task = self._worker_task
        if not task:
            return
        if task.done():
            self._worker_task = None
            return

        sent = False
        try:
            self._queue.put_nowait(self._STOP_SENTINEL)
            sent = True
        except asyncio.QueueFull:
            pass

        if not sent:
            try:
                await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5)
            except asyncio.TimeoutError:
                task.cancel()

        try:
            await asyncio.wait_for(task, timeout=1.5)
        except asyncio.TimeoutError:
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                # CancelledError derives from BaseException, so it has to be
                # named explicitly; otherwise awaiting the cancelled worker
                # would leak the cancellation into finalize()/shutdown().
                pass
        except asyncio.CancelledError:
            pass
        finally:
            self._worker_task = None

    async def _worker_loop(self) -> None:
        """Consume queued jobs until the stop sentinel is received."""
        while True:
            item = await self._queue.get()
            try:
                if item is self._STOP_SENTINEL:
                    return

                assert isinstance(item, _HistoryTranscriptJob)
                await self._write_with_retry(item)
            except Exception as exc:
                logger.warning("History worker write failed unexpectedly: {}", exc)
            finally:
                # Always balance get() so queue.join() in _drain_queue returns.
                self._queue.task_done()

    async def _write_with_retry(self, job: "_HistoryTranscriptJob") -> bool:
        """Write one transcript job, retrying with exponential backoff.

        Makes up to ``retry_max_attempts + 1`` attempts. A falsy return value
        and an exception raised by the writer are both treated as a failed
        attempt. Returns ``True`` on the first success, ``False`` once the
        attempts are exhausted (the job is then dropped with a warning).
        """
        for attempt in range(self._retry_max_attempts + 1):
            try:
                ok = await self._history_writer.add_transcript(
                    call_id=job.call_id,
                    turn_index=job.turn_index,
                    speaker=job.speaker,
                    content=job.content,
                    start_ms=job.start_ms,
                    end_ms=job.end_ms,
                    duration_ms=job.duration_ms,
                )
            except Exception as exc:
                # A raising writer must not bypass the retry budget.
                logger.warning(
                    "History write raised call_id={} turn={} attempt={}: {}",
                    job.call_id,
                    job.turn_index,
                    attempt + 1,
                    exc,
                )
                ok = False
            if ok:
                return True

            if attempt >= self._retry_max_attempts:
                logger.warning(
                    "History write dropped after retries call_id={} turn={}",
                    job.call_id,
                    job.turn_index,
                )
                return False

            if self._retry_backoff_sec > 0:
                await asyncio.sleep(self._retry_backoff_sec * (2**attempt))
        return False
|
||||
17
engine/core/ports/__init__.py
Normal file
17
engine/core/ports/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Port interfaces for engine-side integration boundaries."""
|
||||
|
||||
from core.ports.backend import (
|
||||
AssistantConfigProvider,
|
||||
BackendGateway,
|
||||
HistoryWriter,
|
||||
KnowledgeSearcher,
|
||||
ToolResourceResolver,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssistantConfigProvider",
|
||||
"BackendGateway",
|
||||
"HistoryWriter",
|
||||
"KnowledgeSearcher",
|
||||
"ToolResourceResolver",
|
||||
]
|
||||
84
engine/core/ports/backend.py
Normal file
84
engine/core/ports/backend.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Backend integration ports.
|
||||
|
||||
These interfaces define the boundary between engine runtime logic and
|
||||
backend-side capabilities (config lookup, history persistence, retrieval,
|
||||
and tool resource discovery).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
|
||||
class AssistantConfigProvider(Protocol):
    """Port for loading trusted assistant runtime configuration."""

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Fetch assistant configuration payload.

        Args:
            assistant_id: Identifier of the assistant to look up.

        Returns:
            The configuration payload as a dict, or ``None`` when no payload
            is available.
        """
        ...
|
||||
|
||||
|
||||
class HistoryWriter(Protocol):
    """Port for persisting call and transcript history."""

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Create a call record and return backend call ID.

        Returns ``None`` when no record could be created.
        """
        ...

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Append one transcript turn segment.

        Returns ``True`` when the segment was stored; a falsy result tells the
        caller the write failed.
        """
        ...

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Finalize a call record.

        Returns ``True`` when the record was finalized.
        """
        ...
|
||||
|
||||
|
||||
class KnowledgeSearcher(Protocol):
    """Port for RAG / knowledge retrieval operations."""

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search a knowledge source and return ranked snippets.

        Args:
            kb_id: Identifier of the knowledge base to search.
            query: Search text.
            n_results: Maximum number of snippets requested (default 5).
        """
        ...
|
||||
|
||||
|
||||
class ToolResourceResolver(Protocol):
    """Port for resolving tool metadata/configuration."""

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Fetch tool resource configuration.

        Returns the configuration dict for ``tool_id``, or ``None`` when
        nothing is available for it.
        """
        ...
|
||||
|
||||
|
||||
class BackendGateway(
    AssistantConfigProvider,
    HistoryWriter,
    KnowledgeSearcher,
    ToolResourceResolver,
    Protocol,
):
    """Composite backend gateway interface used by engine services.

    Combines the assistant-config, history, knowledge-search and
    tool-resource ports into one structural type; it adds no members of its
    own.
    """
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
"""Session management for active calls."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import (
|
||||
create_history_call_record,
|
||||
add_history_transcript,
|
||||
finalize_history_call_record,
|
||||
)
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from core.conversation import ConversationTurn
|
||||
from core.history_bridge import SessionHistoryBridge
|
||||
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
||||
from app.config import settings
|
||||
from services.base import LLMMessage
|
||||
@@ -49,7 +46,39 @@ class Session:
|
||||
Uses full duplex voice conversation pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
|
||||
TRACK_AUDIO_IN = "audio_in"
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
AUDIO_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||
_CLIENT_METADATA_OVERRIDES = {
|
||||
"firstTurnMode",
|
||||
"greeting",
|
||||
"generatedOpenerEnabled",
|
||||
"systemPrompt",
|
||||
"output",
|
||||
"bargeIn",
|
||||
"knowledge",
|
||||
"knowledgeBaseId",
|
||||
"history",
|
||||
"userId",
|
||||
"assistantId",
|
||||
"source",
|
||||
}
|
||||
_CLIENT_METADATA_ID_KEYS = {
|
||||
"appId",
|
||||
"app_id",
|
||||
"channel",
|
||||
"configVersionId",
|
||||
"config_version_id",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
transport: BaseTransport,
|
||||
use_duplex: bool = None,
|
||||
backend_gateway: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Initialize session.
|
||||
|
||||
@@ -61,12 +90,23 @@ class Session:
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||
self._backend_gateway = backend_gateway or build_backend_adapter_from_settings()
|
||||
self._history_bridge = SessionHistoryBridge(
|
||||
history_writer=self._backend_gateway,
|
||||
enabled=settings.history_enabled,
|
||||
queue_max_size=settings.history_queue_max_size,
|
||||
retry_max_attempts=settings.history_retry_max_attempts,
|
||||
retry_backoff_sec=settings.history_retry_backoff_sec,
|
||||
finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec,
|
||||
)
|
||||
|
||||
self.pipeline = DuplexPipeline(
|
||||
transport=transport,
|
||||
session_id=session_id,
|
||||
system_prompt=settings.duplex_system_prompt,
|
||||
greeting=settings.duplex_greeting
|
||||
greeting=settings.duplex_greeting,
|
||||
knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None),
|
||||
tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None),
|
||||
)
|
||||
|
||||
# Session state
|
||||
@@ -78,17 +118,15 @@ class Session:
|
||||
self.authenticated: bool = False
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
self._history_call_id: Optional[str] = None
|
||||
self._history_turn_index: int = 0
|
||||
self._history_call_started_mono: Optional[float] = None
|
||||
self._history_finalized: bool = False
|
||||
self.current_track_id: str = self.TRACK_CONTROL
|
||||
self._event_seq: int = 0
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self._cleaned_up = False
|
||||
self.workflow_runner: Optional[WorkflowRunner] = None
|
||||
self._workflow_last_user_text: str = ""
|
||||
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
|
||||
|
||||
self.pipeline.set_event_sequence_provider(self._next_event_seq)
|
||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||
|
||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
@@ -129,13 +167,47 @@ class Session:
|
||||
"client",
|
||||
"Audio received before session.start",
|
||||
"protocol.order",
|
||||
stage="protocol",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pipeline.process_audio(audio_bytes)
|
||||
if not audio_bytes:
|
||||
return
|
||||
if len(audio_bytes) % 2 != 0:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Invalid PCM payload: odd number of bytes",
|
||||
"audio.invalid_pcm",
|
||||
stage="audio",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
frame_bytes = self.AUDIO_FRAME_BYTES
|
||||
if len(audio_bytes) % frame_bytes != 0:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Audio frame size must be a multiple of {frame_bytes} bytes (20ms PCM)",
|
||||
"audio.frame_size_mismatch",
|
||||
stage="audio",
|
||||
retryable=False,
|
||||
)
|
||||
return
|
||||
|
||||
for i in range(0, len(audio_bytes), frame_bytes):
|
||||
frame = audio_bytes[i : i + frame_bytes]
|
||||
await self.pipeline.process_audio(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
await self._send_error(
|
||||
"server",
|
||||
f"Audio processing failed: {e}",
|
||||
"audio.processing_failed",
|
||||
stage="audio",
|
||||
retryable=True,
|
||||
)
|
||||
|
||||
async def _handle_v1_message(self, message: Any) -> None:
|
||||
"""Route validated WS v1 message to handlers."""
|
||||
@@ -176,7 +248,7 @@ class Session:
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
await self.pipeline.handle_tool_call_results([item.model_dump() for item in message.results])
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
@@ -198,9 +270,9 @@ class Session:
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
auth_payload = message.auth or {}
|
||||
api_key = auth_payload.get("apiKey")
|
||||
jwt = auth_payload.get("jwt")
|
||||
auth_payload = message.auth
|
||||
api_key = auth_payload.apiKey if auth_payload else None
|
||||
jwt = auth_payload.jwt if auth_payload else None
|
||||
|
||||
if settings.ws_api_key:
|
||||
if api_key != settings.ws_api_key:
|
||||
@@ -217,10 +289,9 @@ class Session:
|
||||
self.authenticated = True
|
||||
self.protocol_version = message.version
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"hello.ack",
|
||||
sessionId=self.id,
|
||||
version=self.protocol_version,
|
||||
)
|
||||
)
|
||||
@@ -231,8 +302,12 @@ class Session:
|
||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||
return
|
||||
|
||||
metadata = message.metadata or {}
|
||||
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
|
||||
raw_metadata = message.metadata or {}
|
||||
workflow_runtime = self._bootstrap_workflow(raw_metadata)
|
||||
server_runtime = await self._load_server_runtime_metadata(raw_metadata, workflow_runtime)
|
||||
client_runtime = self._sanitize_client_metadata(raw_metadata)
|
||||
metadata = self._merge_runtime_metadata(server_runtime, self._sanitize_untrusted_runtime_metadata(workflow_runtime))
|
||||
metadata = self._merge_runtime_metadata(metadata, client_runtime)
|
||||
|
||||
# Create history call record early so later turn callbacks can append transcripts.
|
||||
await self._start_history_bridge(metadata)
|
||||
@@ -248,28 +323,37 @@ class Session:
|
||||
|
||||
self.state = "accepted"
|
||||
self.ws_state = WsSessionState.ACTIVE
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"session.started",
|
||||
sessionId=self.id,
|
||||
trackId=self.current_track_id,
|
||||
audio=message.audio or {},
|
||||
tracks={
|
||||
"audio_in": self.TRACK_AUDIO_IN,
|
||||
"audio_out": self.TRACK_AUDIO_OUT,
|
||||
"control": self.TRACK_CONTROL,
|
||||
},
|
||||
audio=message.audio.model_dump() if message.audio else {},
|
||||
)
|
||||
)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"config.resolved",
|
||||
trackId=self.TRACK_CONTROL,
|
||||
config=self._build_config_resolved(metadata),
|
||||
)
|
||||
)
|
||||
if self.workflow_runner and self._workflow_initial_node:
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.started",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
workflowName=self.workflow_runner.name,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
nodeName=self._workflow_initial_node.name,
|
||||
@@ -285,17 +369,24 @@ class Session:
|
||||
stop_reason = reason or "client_requested"
|
||||
self.state = "hungup"
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"session.stopped",
|
||||
sessionId=self.id,
|
||||
reason=stop_reason,
|
||||
)
|
||||
)
|
||||
await self._finalize_history(status="connected")
|
||||
await self.transport.close()
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||
async def _send_error(
|
||||
self,
|
||||
sender: str,
|
||||
error_message: str,
|
||||
code: str,
|
||||
stage: Optional[str] = None,
|
||||
retryable: Optional[bool] = None,
|
||||
track_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
@@ -304,13 +395,26 @@ class Session:
|
||||
error_message: Error message
|
||||
code: Machine-readable error code
|
||||
"""
|
||||
await self.transport.send_event(
|
||||
resolved_stage = stage or self._infer_error_stage(code)
|
||||
resolved_retryable = retryable if retryable is not None else (resolved_stage in {"asr", "llm", "tts", "tool", "audio"})
|
||||
resolved_track_id = track_id or self._error_track_id(resolved_stage, code)
|
||||
await self._send_event(
|
||||
ev(
|
||||
"error",
|
||||
sender=sender,
|
||||
code=code,
|
||||
message=error_message,
|
||||
trackId=self.current_track_id,
|
||||
stage=resolved_stage,
|
||||
retryable=resolved_retryable,
|
||||
trackId=resolved_track_id,
|
||||
data={
|
||||
"error": {
|
||||
"stage": resolved_stage,
|
||||
"code": code,
|
||||
"message": error_message,
|
||||
"retryable": resolved_retryable,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -329,11 +433,12 @@ class Session:
|
||||
logger.info(f"Session {self.id} cleaning up")
|
||||
await self._finalize_history(status="connected")
|
||||
await self.pipeline.cleanup()
|
||||
await self._history_bridge.shutdown()
|
||||
await self.transport.close()
|
||||
|
||||
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
|
||||
"""Initialize backend history call record for this session."""
|
||||
if self._history_call_id:
|
||||
if self._history_bridge.call_id:
|
||||
return
|
||||
|
||||
history_meta: Dict[str, Any] = {}
|
||||
@@ -349,7 +454,7 @@ class Session:
|
||||
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
|
||||
source = str(history_meta.get("source", metadata.get("source", "debug")))
|
||||
|
||||
call_id = await create_history_call_record(
|
||||
call_id = await self._history_bridge.start_call(
|
||||
user_id=user_id,
|
||||
assistant_id=str(assistant_id) if assistant_id else None,
|
||||
source=source,
|
||||
@@ -357,10 +462,6 @@ class Session:
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
self._history_call_id = call_id
|
||||
self._history_call_started_mono = time.monotonic()
|
||||
self._history_turn_index = 0
|
||||
self._history_finalized = False
|
||||
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
||||
|
||||
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
|
||||
@@ -372,48 +473,11 @@ class Session:
|
||||
elif role == "assistant":
|
||||
await self._maybe_advance_workflow(turn.text.strip())
|
||||
|
||||
if not self._history_call_id:
|
||||
return
|
||||
if not turn.text or not turn.text.strip():
|
||||
return
|
||||
|
||||
role = (turn.role or "").lower()
|
||||
speaker = "human" if role == "user" else "ai"
|
||||
|
||||
end_ms = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
|
||||
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
|
||||
start_ms = max(0, end_ms - estimated_duration_ms)
|
||||
|
||||
turn_index = self._history_turn_index
|
||||
await add_history_transcript(
|
||||
call_id=self._history_call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=turn.text.strip(),
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
duration_ms=max(1, end_ms - start_ms),
|
||||
)
|
||||
self._history_turn_index += 1
|
||||
self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "")
|
||||
|
||||
async def _finalize_history(self, status: str) -> None:
|
||||
"""Finalize history call record once."""
|
||||
if not self._history_call_id or self._history_finalized:
|
||||
return
|
||||
|
||||
duration_seconds = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
|
||||
|
||||
ok = await finalize_history_call_record(
|
||||
call_id=self._history_call_id,
|
||||
status=status,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
if ok:
|
||||
self._history_finalized = True
|
||||
await self._history_bridge.finalize(status=status)
|
||||
|
||||
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse workflow payload and return initial runtime overrides."""
|
||||
@@ -483,10 +547,9 @@ class Session:
|
||||
node = transition.node
|
||||
edge = transition.edge
|
||||
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.edge.taken",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
edgeId=edge.id,
|
||||
fromNodeId=edge.from_node_id,
|
||||
@@ -494,10 +557,9 @@ class Session:
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
nodeName=node.name,
|
||||
@@ -510,10 +572,9 @@ class Session:
|
||||
self.pipeline.apply_runtime_overrides(node_runtime)
|
||||
|
||||
if node.node_type == "tool":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.tool.requested",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
tool=node.tool or {},
|
||||
@@ -522,10 +583,9 @@ class Session:
|
||||
return
|
||||
|
||||
if node.node_type == "human_transfer":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.human_transfer",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
)
|
||||
@@ -534,16 +594,77 @@ class Session:
|
||||
return
|
||||
|
||||
if node.node_type == "end":
|
||||
await self.transport.send_event(
|
||||
await self._send_event(
|
||||
ev(
|
||||
"workflow.ended",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=node.id,
|
||||
)
|
||||
)
|
||||
await self._handle_session_stop("workflow_end")
|
||||
|
||||
def _next_event_seq(self) -> int:
|
||||
self._event_seq += 1
|
||||
return self._event_seq
|
||||
|
||||
def _event_source(self, event_type: str) -> str:
|
||||
if event_type.startswith("workflow."):
|
||||
return "system"
|
||||
if event_type.startswith("session.") or event_type.startswith("hello.") or event_type == "heartbeat":
|
||||
return "system"
|
||||
if event_type == "error":
|
||||
return "system"
|
||||
return "system"
|
||||
|
||||
def _infer_error_stage(self, code: str) -> str:
|
||||
normalized = str(code or "").strip().lower()
|
||||
if normalized.startswith("audio."):
|
||||
return "audio"
|
||||
if normalized.startswith("tool."):
|
||||
return "tool"
|
||||
if normalized.startswith("asr."):
|
||||
return "asr"
|
||||
if normalized.startswith("llm."):
|
||||
return "llm"
|
||||
if normalized.startswith("tts."):
|
||||
return "tts"
|
||||
return "protocol"
|
||||
|
||||
def _error_track_id(self, stage: str, code: str) -> str:
|
||||
if stage in {"audio", "asr"}:
|
||||
return self.TRACK_AUDIO_IN
|
||||
if stage in {"llm", "tts", "tool"}:
|
||||
return self.TRACK_AUDIO_OUT
|
||||
if str(code or "").strip().lower().startswith("auth."):
|
||||
return self.TRACK_CONTROL
|
||||
return self.TRACK_CONTROL
|
||||
|
||||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = str(event.get("type") or "")
|
||||
source = str(event.get("source") or self._event_source(event_type))
|
||||
track_id = event.get("trackId") or self.TRACK_CONTROL
|
||||
|
||||
data = event.get("data")
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
for k, v in event.items():
|
||||
if k in {"type", "timestamp", "sessionId", "seq", "source", "trackId", "data"}:
|
||||
continue
|
||||
data.setdefault(k, v)
|
||||
|
||||
event["sessionId"] = self.id
|
||||
event["seq"] = self._next_event_seq()
|
||||
event["source"] = source
|
||||
event["trackId"] = track_id
|
||||
event["data"] = data
|
||||
return event
|
||||
|
||||
async def _send_event(self, event: Dict[str, Any]) -> None:
|
||||
await self.transport.send_event(self._envelope_event(event))
|
||||
|
||||
async def send_heartbeat(self) -> None:
    """Send a ``heartbeat`` event on the control track."""
    heartbeat = ev("heartbeat", trackId=self.TRACK_CONTROL)
    await self._send_event(heartbeat)
|
||||
|
||||
async def _workflow_llm_route(
|
||||
self,
|
||||
node: WorkflowNodeDef,
|
||||
@@ -629,6 +750,137 @@ class Session:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
async def _load_server_runtime_metadata(
|
||||
self,
|
||||
client_metadata: Dict[str, Any],
|
||||
workflow_runtime: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Load trusted runtime metadata from backend assistant config."""
|
||||
assistant_id = (
|
||||
workflow_runtime.get("assistantId")
|
||||
or client_metadata.get("assistantId")
|
||||
or client_metadata.get("appId")
|
||||
or client_metadata.get("app_id")
|
||||
)
|
||||
if assistant_id is None:
|
||||
return {}
|
||||
|
||||
provider = getattr(self._backend_gateway, "fetch_assistant_config", None)
|
||||
if not callable(provider):
|
||||
return {}
|
||||
|
||||
payload = await provider(str(assistant_id).strip())
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
|
||||
assistant_cfg: Dict[str, Any] = {}
|
||||
session_start_cfg = payload.get("sessionStartMetadata")
|
||||
if isinstance(session_start_cfg, dict):
|
||||
assistant_cfg.update(session_start_cfg)
|
||||
if isinstance(payload.get("assistant"), dict):
|
||||
assistant_cfg.update(payload.get("assistant"))
|
||||
elif not assistant_cfg:
|
||||
assistant_cfg = payload
|
||||
|
||||
if not isinstance(assistant_cfg, dict):
|
||||
return {}
|
||||
|
||||
runtime: Dict[str, Any] = {}
|
||||
passthrough_keys = {
|
||||
"firstTurnMode",
|
||||
"generatedOpenerEnabled",
|
||||
"output",
|
||||
"bargeIn",
|
||||
"knowledgeBaseId",
|
||||
"knowledge",
|
||||
"history",
|
||||
"userId",
|
||||
"source",
|
||||
"tools",
|
||||
"services",
|
||||
"configVersionId",
|
||||
"config_version_id",
|
||||
}
|
||||
for key in passthrough_keys:
|
||||
if key in assistant_cfg:
|
||||
runtime[key] = assistant_cfg[key]
|
||||
|
||||
if assistant_cfg.get("systemPrompt") is not None:
|
||||
runtime["systemPrompt"] = str(assistant_cfg.get("systemPrompt") or "")
|
||||
elif assistant_cfg.get("prompt") is not None:
|
||||
runtime["systemPrompt"] = str(assistant_cfg.get("prompt") or "")
|
||||
|
||||
if assistant_cfg.get("greeting") is not None:
|
||||
runtime["greeting"] = assistant_cfg.get("greeting")
|
||||
elif assistant_cfg.get("opener") is not None:
|
||||
runtime["greeting"] = assistant_cfg.get("opener")
|
||||
|
||||
resolved_assistant_id = (
|
||||
assistant_cfg.get("assistantId")
|
||||
or payload.get("assistantId")
|
||||
or assistant_id
|
||||
)
|
||||
runtime["assistantId"] = str(resolved_assistant_id)
|
||||
|
||||
if runtime.get("configVersionId") is None and payload.get("configVersionId") is not None:
|
||||
runtime["configVersionId"] = payload.get("configVersionId")
|
||||
if runtime.get("configVersionId") is None and payload.get("config_version_id") is not None:
|
||||
runtime["configVersionId"] = payload.get("config_version_id")
|
||||
|
||||
if runtime.get("config_version_id") is not None and runtime.get("configVersionId") is None:
|
||||
runtime["configVersionId"] = runtime.get("config_version_id")
|
||||
return runtime
|
||||
|
||||
def _sanitize_untrusted_runtime_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitize untrusted metadata sources.
|
||||
|
||||
This keeps only a small override whitelist and stable config ID fields.
|
||||
"""
|
||||
if not isinstance(metadata, dict):
|
||||
return {}
|
||||
|
||||
sanitized: Dict[str, Any] = {}
|
||||
for key in self._CLIENT_METADATA_ID_KEYS:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
for key in self._CLIENT_METADATA_OVERRIDES:
|
||||
if key in metadata:
|
||||
sanitized[key] = metadata[key]
|
||||
|
||||
return sanitized
|
||||
|
||||
def _sanitize_client_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply client metadata whitelist and remove forbidden secrets."""
|
||||
sanitized = self._sanitize_untrusted_runtime_metadata(metadata)
|
||||
if isinstance(metadata.get("services"), dict):
|
||||
logger.warning(
|
||||
"Session {} provided metadata.services from client; client-side service config is ignored",
|
||||
self.id,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
def _build_config_resolved(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build public resolved config payload (secrets removed)."""
|
||||
system_prompt = str(metadata.get("systemPrompt") or self.pipeline.conversation.system_prompt or "")
|
||||
prompt_hash = hashlib.sha256(system_prompt.encode("utf-8")).hexdigest() if system_prompt else None
|
||||
runtime = self.pipeline.resolved_runtime_config()
|
||||
|
||||
return {
|
||||
"appId": metadata.get("appId") or metadata.get("app_id") or metadata.get("assistantId"),
|
||||
"channel": metadata.get("channel"),
|
||||
"configVersionId": metadata.get("configVersionId") or metadata.get("config_version_id"),
|
||||
"prompt": {"sha256": prompt_hash},
|
||||
"output": runtime.get("output", {}),
|
||||
"services": runtime.get("services", {}),
|
||||
"tools": runtime.get("tools", {}),
|
||||
"tracks": {
|
||||
"audio_in": self.TRACK_AUDIO_IN,
|
||||
"audio_out": self.TRACK_AUDIO_OUT,
|
||||
"control": self.TRACK_CONTROL,
|
||||
},
|
||||
}
|
||||
|
||||
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Best-effort extraction of a JSON object from freeform text."""
|
||||
try:
|
||||
|
||||
@@ -4,11 +4,13 @@ import asyncio
|
||||
import ast
|
||||
import operator
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.backend_client import fetch_tool_resource
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
|
||||
ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
|
||||
|
||||
_BIN_OPS = {
|
||||
ast.Add: operator.add,
|
||||
@@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Default tool resource resolver via backend adapter."""
|
||||
adapter = build_backend_adapter_from_settings()
|
||||
return await adapter.fetch_tool_resource(tool_id)
|
||||
|
||||
|
||||
async def execute_server_tool(
|
||||
tool_call: Dict[str, Any],
|
||||
tool_resource_fetcher: Optional[ToolResourceFetcher] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a server-side tool and return normalized result payload."""
|
||||
call_id = str(tool_call.get("id") or "").strip()
|
||||
tool_name = _extract_tool_name(tool_call)
|
||||
args = _extract_tool_args(tool_call)
|
||||
resource_fetcher = tool_resource_fetcher or fetch_tool_resource
|
||||
|
||||
if tool_name == "calculator":
|
||||
expression = str(args.get("expression") or "").strip()
|
||||
@@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
|
||||
resource = await fetch_tool_resource(tool_name)
|
||||
resource = await resource_fetcher(tool_name)
|
||||
if resource and str(resource.get("category") or "") == "query":
|
||||
method = str(resource.get("http_method") or "GET").strip().upper()
|
||||
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
|
||||
|
||||
47
engine/docs/backend_integration.md
Normal file
47
engine/docs/backend_integration.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# Backend Integration and History Bridge
|
||||
|
||||
This engine uses adapter-based backend integration so core runtime logic can run
|
||||
with or without an external backend service.
|
||||
|
||||
## Runtime Modes
|
||||
|
||||
Configure with environment variables:
|
||||
|
||||
- `BACKEND_MODE=auto|http|disabled`
|
||||
- `BACKEND_URL`
|
||||
- `BACKEND_TIMEOUT_SEC`
|
||||
- `HISTORY_ENABLED=true|false`
|
||||
|
||||
Mode behavior:
|
||||
|
||||
- `auto`: use HTTP backend adapter only when `BACKEND_URL` is set.
|
||||
- `http`: always select the HTTP backend adapter; if `BACKEND_URL` is missing, the engine falls back to the null adapter.
|
||||
- `disabled`: force null adapter and run engine-only.
|
||||
|
||||
## Architecture
|
||||
|
||||
- Ports: `core/ports/backend.py`
|
||||
- Adapters: `app/backend_adapters.py`
|
||||
- Compatibility wrappers: `app/backend_client.py`
|
||||
|
||||
`Session` and `DuplexPipeline` receive backend capabilities via injected adapter
|
||||
methods instead of hard-coding backend client imports.
|
||||
|
||||
## Async History Writes
|
||||
|
||||
Session history persistence is handled by `core/history_bridge.py`.
|
||||
|
||||
Design:
|
||||
|
||||
- transcript writes are queued with `put_nowait` (non-blocking turn path)
|
||||
- background worker drains queue
|
||||
- failed writes retry with exponential backoff
|
||||
- finalize waits briefly for queue drain before sending call finalize
|
||||
- finalize is idempotent
|
||||
|
||||
Related settings:
|
||||
|
||||
- `HISTORY_QUEUE_MAX_SIZE`
|
||||
- `HISTORY_RETRY_MAX_ATTEMPTS`
|
||||
- `HISTORY_RETRY_BACKOFF_SEC`
|
||||
- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC`
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
This document defines the public WebSocket protocol for the `/ws` endpoint.
|
||||
|
||||
Validation policy:
|
||||
- WS v1 JSON control messages are validated strictly.
|
||||
- Unknown top-level fields are rejected for all defined client message types.
|
||||
- `hello.version` is fixed to `"v1"`.
|
||||
|
||||
## Transport
|
||||
|
||||
- A single WebSocket connection carries:
|
||||
@@ -52,43 +57,26 @@ Rules:
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"appId": "assistant_123",
|
||||
"channel": "web",
|
||||
"configVersionId": "cfg_20260217_01",
|
||||
"client": "web-debug",
|
||||
"output": {
|
||||
"mode": "audio"
|
||||
},
|
||||
"systemPrompt": "You are concise.",
|
||||
"greeting": "Hi, how can I help?",
|
||||
"services": {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o-mini",
|
||||
"apiKey": "sk-...",
|
||||
"baseUrl": "https://api.openai.com/v1"
|
||||
},
|
||||
"asr": {
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/SenseVoiceSmall",
|
||||
"apiKey": "sf-...",
|
||||
"interimIntervalMs": 500,
|
||||
"minAudioMs": 300
|
||||
},
|
||||
"tts": {
|
||||
"enabled": true,
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/CosyVoice2-0.5B",
|
||||
"apiKey": "sf-...",
|
||||
"voice": "anna",
|
||||
"speed": 1.0
|
||||
}
|
||||
}
|
||||
"greeting": "Hi, how can I help?"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`metadata.services` is optional. If omitted, server defaults to environment configuration.
|
||||
Rules:
|
||||
- Client-side `metadata.services` is ignored.
|
||||
- Service config (including secrets) is resolved server-side (env/backend).
|
||||
- Client should pass stable IDs (`appId`, `channel`, `configVersionId`) plus small runtime overrides (e.g. `output`, `bargeIn`, greeting/prompt style hints).
|
||||
|
||||
Text-only mode:
|
||||
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`.
|
||||
- Set `metadata.output.mode = "text"`.
|
||||
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
|
||||
|
||||
### `input.text`
|
||||
@@ -121,6 +109,7 @@ Text-only mode:
|
||||
### `tool_call.results`
|
||||
|
||||
Client tool execution results returned to server.
|
||||
Only needed when `assistant.tool_call.executor == "client"` (default execution is server-side).
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -138,21 +127,36 @@ Client tool execution results returned to server.
|
||||
|
||||
## Server -> Client Events
|
||||
|
||||
All server events include:
|
||||
All server events include an envelope:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "event.name",
|
||||
"timestamp": 1730000000000
|
||||
"timestamp": 1730000000000,
|
||||
"sessionId": "sess_xxx",
|
||||
"seq": 42,
|
||||
"source": "asr",
|
||||
"trackId": "audio_in",
|
||||
"data": {}
|
||||
}
|
||||
```
|
||||
|
||||
Envelope notes:
|
||||
- `seq` is monotonically increasing within one session (for replay/resume).
|
||||
- `source` is one of: `asr | llm | tts | tool | system | client | server`.
|
||||
- For `assistant.tool_result`, `source` may be `client` or `server` to indicate execution side.
|
||||
- `data` is structured payload; legacy top-level fields are kept for compatibility.
|
||||
|
||||
Common events:
|
||||
|
||||
- `hello.ack`
|
||||
- Fields: `sessionId`, `version`
|
||||
- `session.started`
|
||||
- Fields: `sessionId`, `trackId`, `audio`
|
||||
- Fields: `sessionId`, `trackId`, `tracks`, `audio`
|
||||
- `config.resolved`
|
||||
- Fields: `sessionId`, `trackId`, `config`
|
||||
- Sent immediately after `session.started`.
|
||||
- Contains effective model/voice/output/tool allowlist/prompt hash, and never includes secrets.
|
||||
- `session.stopped`
|
||||
- Fields: `sessionId`, `reason`
|
||||
- `heartbeat`
|
||||
@@ -169,9 +173,10 @@ Common events:
|
||||
- `assistant.response.final`
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.tool_call`
|
||||
- Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`)
|
||||
- Fields: `trackId`, `tool_call`, `tool_call_id`, `tool_name`, `arguments`, `executor`, `timeout_ms`
|
||||
- `assistant.tool_result`
|
||||
- Fields: `trackId`, `source`, `result`
|
||||
- Fields: `trackId`, `source`, `result`, `tool_call_id`, `tool_name`, `ok`, `error`
|
||||
- `error`: `{ code, message, retryable }` when `ok=false`
|
||||
- `output.audio.start`
|
||||
- Fields: `trackId`
|
||||
- `output.audio.end`
|
||||
@@ -182,16 +187,54 @@ Common events:
|
||||
- Fields: `trackId`, `latencyMs`
|
||||
- `error`
|
||||
- Fields: `sender`, `code`, `message`, `stage`, `retryable`, `trackId`, `data.error` (see "Error Structure")
|
||||
- `trackId` convention:
|
||||
- `audio_in` for `stage in {audio, asr}`
|
||||
- `audio_out` for `stage in {llm, tts, tool}`
|
||||
- `control` otherwise (including protocol/auth errors)
|
||||
|
||||
Track IDs (MVP fixed values):
|
||||
- `audio_in`: ASR/VAD input-side events (`input.*`, `transcript.*`)
|
||||
- `audio_out`: assistant output-side events (`assistant.*`, `output.audio.*`, `response.interrupted`, `metrics.ttfb`)
|
||||
- `control`: session/control events (`session.*`, `hello.*`, `error`, `config.resolved`)
|
||||
|
||||
Correlation IDs (`event.data`):
|
||||
- `turn_id`: one user-assistant interaction turn.
|
||||
- `utterance_id`: one ASR final utterance.
|
||||
- `response_id`: one assistant response generation.
|
||||
- `tool_call_id`: one tool invocation.
|
||||
- `tts_id`: one TTS playback segment.
|
||||
|
||||
## Binary Audio Frames
|
||||
|
||||
After `session.started`, client may send binary PCM chunks continuously.
|
||||
|
||||
Recommended format:
|
||||
- 16-bit signed little-endian PCM.
|
||||
- 1 channel.
|
||||
- 16000 Hz.
|
||||
- 20ms frames (640 bytes) preferred.
|
||||
MVP fixed format:
|
||||
- 16-bit signed little-endian PCM (`pcm_s16le`)
|
||||
- mono (1 channel)
|
||||
- 16000 Hz
|
||||
- 20ms frame = 640 bytes
|
||||
|
||||
Framing rules:
|
||||
- Binary audio frame unit is 640 bytes.
|
||||
- A WS binary message may carry one or multiple complete 640-byte frames.
|
||||
- Non-640-multiple payloads are rejected as `audio.frame_size_mismatch`; that WS message is dropped (no partial buffering/reassembly).
|
||||
|
||||
TTS boundary events:
|
||||
- `output.audio.start` and `output.audio.end` mark assistant playback boundaries.
|
||||
|
||||
## Event Throttling
|
||||
|
||||
To keep client rendering and server load stable, v1 applies/recommends:
|
||||
- `transcript.delta`: merge to ~200-500ms cadence (server default: 300ms).
|
||||
- `assistant.response.delta`: merge to ~50-100ms cadence (server default: 80ms).
|
||||
- Metrics streams (if enabled beyond `metrics.ttfb`): emit every ~500-1000ms.
|
||||
|
||||
## Error Structure
|
||||
|
||||
`error` keeps legacy top-level fields (`code`, `message`) and adds structured info:
|
||||
- `stage`: `protocol | asr | llm | tts | tool | audio`
|
||||
- `retryable`: boolean
|
||||
- `data.error`: `{ stage, code, message, retryable }`
|
||||
|
||||
## Compatibility
|
||||
|
||||
|
||||
520
engine/docs/ws_v1_schema_zh.md
Normal file
520
engine/docs/ws_v1_schema_zh.md
Normal file
@@ -0,0 +1,520 @@
|
||||
# WS v1 协议完整说明(中文)
|
||||
|
||||
本文档描述 `/ws` 端点的 WebSocket v1 协议,覆盖:
|
||||
- 客户端输入(JSON 文本消息 + 二进制音频);
|
||||
- 服务端输出(JSON 事件 + 二进制音频);
|
||||
- 每个参数的类型、约束、含义与使用方式;
|
||||
- 握手顺序、状态机、错误语义与实现细节。
|
||||
|
||||
实现对照来源:
|
||||
- `models/ws_v1.py`
|
||||
- `core/session.py`
|
||||
- `core/duplex_pipeline.py`
|
||||
- `app/main.py`
|
||||
|
||||
---
|
||||
|
||||
## 1. 传输与基础规则
|
||||
|
||||
- 连接地址:`ws://<host>/ws`
|
||||
- 单连接双通道承载:
|
||||
- 文本帧:JSON 控制消息(严格校验 schema)
|
||||
- 二进制帧:原始 PCM 音频
|
||||
- JSON 校验策略:
|
||||
- 所有已定义客户端消息都 `extra="forbid"`,即不允许未声明字段;
|
||||
- `hello.version` 固定必须是 `"v1"`;
|
||||
- 缺失 `type` 或未知 `type` 会返回协议错误。
|
||||
|
||||
---
|
||||
|
||||
## 2. 状态机与消息顺序
|
||||
|
||||
### 2.1 服务端状态
|
||||
|
||||
- `WAIT_HELLO`:等待 `hello`
|
||||
- `WAIT_START`:已通过握手,等待 `session.start`
|
||||
- `ACTIVE`:会话运行中,可收发文本/音频
|
||||
- `STOPPED`:会话结束
|
||||
|
||||
### 2.2 正确顺序
|
||||
|
||||
1. 客户端发送 `hello`
|
||||
2. 服务端返回 `hello.ack`
|
||||
3. 客户端发送 `session.start`
|
||||
4. 服务端返回 `session.started`,并紧接着下发 `config.resolved`
|
||||
5. 客户端可持续发送:
|
||||
- 二进制音频
|
||||
- `input.text`(可选)
|
||||
- `response.cancel`(可选)
|
||||
- `tool_call.results`(可选)
|
||||
6. 客户端发送 `session.stop` 或直接断开连接
|
||||
|
||||
顺序错误会返回 `error`,`code = "protocol.order"`。
|
||||
|
||||
---
|
||||
|
||||
## 3. 客户端 -> 服务端消息(输入)
|
||||
|
||||
### 3.1 `hello`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
"auth": {
|
||||
"apiKey": "optional-api-key",
|
||||
"jwt": "optional-jwt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 |
|
||||
| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 |
|
||||
| `auth` | object \| null | 否 | 仅允许 `apiKey`、`jwt` | 认证载荷 | 认证策略由服务端配置决定 |
|
||||
| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 |
|
||||
| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 |
|
||||
|
||||
认证行为:
|
||||
- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。
|
||||
- 若 `WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY`:`auth.apiKey` 或 `auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。
|
||||
|
||||
### 3.2 `session.start`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": 16000,
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"appId": "assistant_123",
|
||||
"channel": "web",
|
||||
"configVersionId": "cfg_20260217_01",
|
||||
"client": "web-debug",
|
||||
"output": {
|
||||
"mode": "audio"
|
||||
},
|
||||
"systemPrompt": "你是简洁助手",
|
||||
"greeting": "你好,我能帮你什么?"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | 固定 `"session.start"` | 启动会话 | 握手后第二阶段消息 |
|
||||
| `audio` | object \| null | 否 | 仅支持固定值 | 音频格式描述 | 仅用于声明;MVP 实际只接受固定 PCM |
|
||||
| `audio.encoding` | string | 否 | 固定 `"pcm_s16le"` | 编码格式 | 非该值会在模型校验层报错 |
|
||||
| `audio.sample_rate_hz` | number | 否 | 固定 `16000` | 采样率 | 16kHz |
|
||||
| `audio.channels` | number | 否 | 固定 `1` | 声道数 | 单声道 |
|
||||
| `metadata` | object \| null | 否 | 任意对象(会被白名单过滤) | 运行时配置 | 用于 app/channel/提示词/输出模式等覆盖 |
|
||||
|
||||
`metadata` 白名单策略(关键):
|
||||
- 允许透传的标识字段(ID 类):
|
||||
- `appId` / `app_id`
|
||||
- `channel`
|
||||
- `configVersionId` / `config_version_id`
|
||||
- 允许透传的覆盖字段:
|
||||
- `firstTurnMode`
|
||||
- `greeting`
|
||||
- `generatedOpenerEnabled`
|
||||
- `systemPrompt`
|
||||
- `output`
|
||||
- `bargeIn`
|
||||
- `knowledge`
|
||||
- `knowledgeBaseId`
|
||||
- `history`
|
||||
- `userId`
|
||||
- `assistantId`
|
||||
- `source`
|
||||
- 客户端传入 `metadata.services` 会被忽略(服务端会记录 warning),服务配置由后端/环境变量决定。
|
||||
|
||||
`output.mode` 用法:
|
||||
- `"audio"`(默认语音输出)
|
||||
- `"text"`(纯文本输出)
|
||||
- 纯文本模式下仍会收到 `assistant.response.delta/final`;
|
||||
- 不会收到 TTS 音频帧与 `output.audio.start/end`。
|
||||
|
||||
### 3.3 `input.text`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "input.text",
|
||||
"text": "你能做什么?"
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | 固定 `"input.text"` | 文本输入 | 跳过 ASR,直接触发 LLM 回答 |
|
||||
| `text` | string | 是 | 非空字符串为佳 | 用户文本 | 用于文本聊天或调试 |
|
||||
|
||||
### 3.4 `response.cancel`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "response.cancel",
|
||||
"graceful": false
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 默认值 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 |
|
||||
| `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 |
|
||||
|
||||
### 3.5 `tool_call.results`
|
||||
|
||||
仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "tool_call.results",
|
||||
"results": [
|
||||
{
|
||||
"tool_call_id": "call_abc123",
|
||||
"name": "weather",
|
||||
"output": { "temp_c": 21, "condition": "sunny" },
|
||||
"status": { "code": 200, "message": "ok" }
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | 固定 `"tool_call.results"` | 工具执行回传 | 客户端工具结果上送 |
|
||||
| `results` | array | 否 | 默认为空数组 | 多个工具结果 | 可批量回传 |
|
||||
| `results[].tool_call_id` | string | 是 | 任意字符串 | 工具调用ID | 必须与 `assistant.tool_call.tool_call_id` 对应 |
|
||||
| `results[].name` | string | 是 | 任意字符串 | 工具名 | 建议与请求一致 |
|
||||
| `results[].output` | any | 否 | 任意 JSON | 工具输出 | 供模型后续组织回答 |
|
||||
| `results[].status` | object | 是 | 包含 `code`、`message` | 执行状态 | 用于判定成功/失败 |
|
||||
| `results[].status.code` | number | 是 | HTTP 风格状态码 | 状态码 | `200-299` 判定成功 |
|
||||
| `results[].status.message` | string | 是 | 任意字符串 | 状态描述 | 例如 `"ok"` / `"timeout"` |
|
||||
|
||||
处理规则:
|
||||
- 未请求过的 `tool_call_id` 会被忽略(防止伪造/串话);
|
||||
- 重复回传会被忽略;
|
||||
- 超时未回传会由服务端合成超时结果(`504`)。
|
||||
|
||||
### 3.6 `session.stop`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "session.stop",
|
||||
"reason": "client_disconnect"
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||
|---|---|---|---|---|---|
|
||||
| `type` | string | 是 | 固定 `"session.stop"` | 结束会话 | 正常结束推荐发送 |
|
||||
| `reason` | string \| null | 否 | 任意字符串 | 结束原因 | 服务端会回传到 `session.stopped.reason` |
|
||||
|
||||
---
|
||||
|
||||
## 4. 二进制音频输入(客户端 -> 服务端)
|
||||
|
||||
在 `session.started` 之后可持续发送二进制音频。
|
||||
|
||||
固定格式(MVP):
|
||||
- 编码:`pcm_s16le`
|
||||
- 采样率:`16000`
|
||||
- 声道:`1`
|
||||
- 帧长:20ms = `640 bytes`
|
||||
|
||||
分包规则:
|
||||
- 单个 WebSocket 二进制消息可包含 1 帧或多帧;
|
||||
- 长度必须是 `640` 的整数倍;
|
||||
- 不是 `640` 倍数会触发 `audio.frame_size_mismatch`,该消息整包丢弃;
|
||||
- 奇数字节长度会触发 `audio.invalid_pcm`。
|
||||
|
||||
---
|
||||
|
||||
## 5. 服务端 -> 客户端事件(输出)
|
||||
|
||||
所有 JSON 事件都包含统一包络字段。
|
||||
|
||||
## 5.1 统一包络(Envelope)
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "event.name",
|
||||
"timestamp": 1730000000000,
|
||||
"sessionId": "sess_xxx",
|
||||
"seq": 42,
|
||||
"source": "asr",
|
||||
"trackId": "audio_in",
|
||||
"data": {}
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 类型 | 含义 | 使用说明 |
|
||||
|---|---|---|---|
|
||||
| `type` | string | 事件类型 | 见下方事件清单 |
|
||||
| `timestamp` | number | 事件时间戳(毫秒) | 由 `ev()` 生成 |
|
||||
| `sessionId` | string | 会话ID | 同一连接固定 |
|
||||
| `seq` | number | 单会话递增序号 | 可用于重放、去重、排序 |
|
||||
| `source` | string | 事件来源 | 常见:`asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` |
|
||||
| `trackId` | string | 事件轨道 | 常用:`audio_in`/`audio_out`/`control` |
|
||||
| `data` | object | 结构化数据 | 顶层业务字段会镜像进 `data` 以兼容旧客户端 |
|
||||
|
||||
关联ID(在 `data` 内自动注入,存在时):
|
||||
- `turn_id`:一次用户-助手对话轮次
|
||||
- `utterance_id`:一次用户语音话语
|
||||
- `response_id`:一次助手生成响应
|
||||
- `tool_call_id`:一次工具调用
|
||||
- `tts_id`:一次 TTS 播放段
|
||||
|
||||
## 5.2 事件类型与参数
|
||||
|
||||
### 5.2.1 会话与控制类
|
||||
|
||||
1. `hello.ack`
|
||||
- 关键字段:`version`
|
||||
- 含义:握手成功,应紧接着发送 `session.start`
|
||||
|
||||
2. `session.started`
|
||||
- 关键字段:
|
||||
- `trackId`
|
||||
- `tracks.audio_in`
|
||||
- `tracks.audio_out`
|
||||
- `tracks.control`
|
||||
- `audio`(回显客户端声明的音频元信息)
|
||||
- 含义:会话进入 ACTIVE,可发音频/文本
|
||||
|
||||
3. `config.resolved`
|
||||
- 关键字段:
|
||||
- `config.appId`
|
||||
- `config.channel`
|
||||
- `config.configVersionId`
|
||||
- `config.prompt.sha256`
|
||||
- `config.output`
|
||||
- `config.services`(去密钥后的有效服务配置)
|
||||
- `config.tools.allowlist`
|
||||
- `config.tracks`
|
||||
- 含义:服务端最终生效配置快照,便于前端展示与排错
|
||||
|
||||
4. `heartbeat`
|
||||
- 关键字段:无业务字段(仅 envelope)
|
||||
- 含义:保活心跳
|
||||
- 默认间隔:`heartbeat_interval_sec`(默认 50s)
|
||||
|
||||
5. `session.stopped`
|
||||
- 关键字段:`reason`
|
||||
- 含义:会话结束确认
|
||||
|
||||
6. `error`
|
||||
- 关键字段:
|
||||
- `sender`
|
||||
- `code`
|
||||
- `message`
|
||||
- `stage`
|
||||
- `retryable`
|
||||
- `trackId`
|
||||
- `data.error`(结构化错误镜像)
|
||||
- 含义:统一错误事件
|
||||
|
||||
### 5.2.2 识别与输入侧(ASR/VAD)
|
||||
|
||||
1. `input.speech_started`
|
||||
- 字段:`probability`
|
||||
- 含义:检测到语音开始
|
||||
|
||||
2. `input.speech_stopped`
|
||||
- 字段:`probability`
|
||||
- 含义:检测到语音结束
|
||||
|
||||
3. `transcript.delta`
|
||||
- 字段:`text`
|
||||
- 含义:ASR 增量识别文本(节流发送)
|
||||
|
||||
4. `transcript.final`
|
||||
- 字段:`text`
|
||||
- 含义:ASR 最终识别文本
|
||||
|
||||
### 5.2.3 输出侧(LLM/TTS/Tool)
|
||||
|
||||
1. `assistant.response.delta`
|
||||
- 字段:`text`
|
||||
- 含义:助手增量文本输出(节流发送)
|
||||
|
||||
2. `assistant.response.final`
|
||||
- 字段:`text`
|
||||
- 含义:助手完整文本输出
|
||||
|
||||
3. `assistant.tool_call`
|
||||
- 字段:
|
||||
- `tool_call_id`
|
||||
- `tool_name`
|
||||
- `arguments`(对象)
|
||||
- `executor`(`client` 或 `server`)
|
||||
- `timeout_ms`
|
||||
- `tool_call`(完整工具调用对象)
|
||||
- 含义:通知客户端发生工具调用(用于可视化或客户端执行)
|
||||
|
||||
4. `assistant.tool_result`
|
||||
- 字段:
|
||||
- `source`(`client` 或 `server`)
|
||||
- `tool_call_id`
|
||||
- `tool_name`
|
||||
- `ok`(boolean)
|
||||
- `error`(失败时 `{code,message,retryable}`)
|
||||
- `result`(原始结果对象)
|
||||
- 含义:工具调用结果回执
|
||||
|
||||
5. `output.audio.start`
|
||||
- 含义:TTS 音频输出开始边界
|
||||
|
||||
6. `output.audio.end`
|
||||
- 含义:TTS 音频输出结束边界
|
||||
|
||||
7. `response.interrupted`
|
||||
- 含义:当前回答被打断(barge-in 或 cancel)
|
||||
|
||||
8. `metrics.ttfb`
|
||||
- 字段:`latencyMs`
|
||||
- 含义:首包音频时延(TTFB)
|
||||
|
||||
### 5.2.4 工作流扩展事件(可选)
|
||||
|
||||
若 `metadata.workflow` 生效,会额外出现:
|
||||
- `workflow.started`
|
||||
- `workflow.node.entered`
|
||||
- `workflow.edge.taken`
|
||||
- `workflow.tool.requested`
|
||||
- `workflow.human_transfer`
|
||||
- `workflow.ended`
|
||||
|
||||
这些事件用于外部可视化工作流状态,不影响基础语音会话协议。
|
||||
|
||||
---
|
||||
|
||||
## 6. 服务端二进制音频输出(服务端 -> 客户端)
|
||||
|
||||
- 音频为 PCM 二进制帧;
|
||||
- 发送单位对齐到 `640 bytes`(不足会补零后发送);
|
||||
- 前端通常结合 `output.audio.start/end` 做播放边界控制;
|
||||
- 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。
|
||||
|
||||
---
|
||||
|
||||
## 7. 错误模型与常见错误码
|
||||
|
||||
统一结构(`error` 事件):
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "error",
|
||||
"sender": "client",
|
||||
"code": "protocol.invalid_message",
|
||||
"message": "Invalid message: ...",
|
||||
"stage": "protocol",
|
||||
"retryable": false,
|
||||
"trackId": "control",
|
||||
"data": {
|
||||
"error": {
|
||||
"stage": "protocol",
|
||||
"code": "protocol.invalid_message",
|
||||
"message": "Invalid message: ...",
|
||||
"retryable": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
字段语义:
|
||||
- `sender`:错误来源角色(如 `client` / `server` / `auth`)
|
||||
- `code`:机器可读错误码
|
||||
- `message`:人类可读描述
|
||||
- `stage`:阶段(`protocol|audio|asr|llm|tts|tool`)
|
||||
- `retryable`:是否建议重试
|
||||
- `trackId`:错误归属轨道
|
||||
|
||||
常见错误码:
|
||||
- `protocol.invalid_json`
|
||||
- `protocol.invalid_message`
|
||||
- `protocol.order`
|
||||
- `protocol.version_unsupported`
|
||||
- `protocol.unsupported`
|
||||
- `auth.invalid_api_key`
|
||||
- `auth.required`
|
||||
- `audio.invalid_pcm`
|
||||
- `audio.frame_size_mismatch`
|
||||
- `audio.processing_failed`
|
||||
- `server.internal`
|
||||
|
||||
---
|
||||
|
||||
## 8. 心跳与超时
|
||||
|
||||
服务端后台任务逻辑:
|
||||
- 每隔约 5 秒检查一次连接;
|
||||
- 超过 `inactivity_timeout_sec`(默认 60 秒)未收到任何客户端消息则关闭会话;
|
||||
- 每隔 `heartbeat_interval_sec`(默认 50 秒)发送一次 `heartbeat`。
|
||||
|
||||
客户端建议:
|
||||
- 持续上行音频或定期发送轻量文本消息,避免被判定闲置;
|
||||
- 用 `heartbeat` + `seq` 检测连接活性和事件乱序。
|
||||
|
||||
---
|
||||
|
||||
## 9. 实战接入建议
|
||||
|
||||
1. 建连后立即发送 `hello`,收到 `hello.ack` 后再发 `session.start`。
|
||||
2. 语音输入严格按 16k/16bit/mono,并保证每个 WS 二进制消息长度是 `640*n`。
|
||||
3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。
|
||||
4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。
|
||||
5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。
|
||||
6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。
|
||||
|
||||
---
|
||||
|
||||
## 10. 最小完整时序示例
|
||||
|
||||
```text
|
||||
Client -> hello
|
||||
Server <- hello.ack
|
||||
Client -> session.start
|
||||
Server <- session.started
|
||||
Server <- config.resolved
|
||||
Client -> (binary pcm frames...)
|
||||
Server <- input.speech_started / transcript.delta / transcript.final
|
||||
Server <- assistant.response.delta / assistant.response.final
|
||||
Server <- output.audio.start
|
||||
Server <- (binary pcm frames...)
|
||||
Server <- output.audio.end
|
||||
Client -> session.stop
|
||||
Server <- session.stopped
|
||||
```
|
||||
|
||||
@@ -59,8 +59,12 @@ class MicrophoneClient:
|
||||
url: str,
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_ms: int = 20,
|
||||
app_id: str = "assistant_demo",
|
||||
channel: str = "mic_client",
|
||||
config_version_id: str = "local-dev",
|
||||
input_device: int = None,
|
||||
output_device: int = None
|
||||
output_device: int = None,
|
||||
track_debug: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize microphone client.
|
||||
@@ -76,8 +80,12 @@ class MicrophoneClient:
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_duration_ms = chunk_duration_ms
|
||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||
self.app_id = app_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.input_device = input_device
|
||||
self.output_device = output_device
|
||||
self.track_debug = track_debug
|
||||
|
||||
# WebSocket connection
|
||||
self.ws = None
|
||||
@@ -107,6 +115,17 @@ class MicrophoneClient:
|
||||
# Verbose mode for streaming LLM responses
|
||||
self.verbose = False
|
||||
|
||||
@staticmethod
|
||||
def _event_ids_suffix(event: dict) -> str:
|
||||
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
||||
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
|
||||
parts = []
|
||||
for key in keys:
|
||||
value = data.get(key, event.get(key))
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
@@ -114,20 +133,30 @@ class MicrophoneClient:
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# Send invite command
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.send_command({
|
||||
"command": "invite",
|
||||
"option": {
|
||||
"codec": "pcm",
|
||||
"sampleRate": self.sample_rate
|
||||
}
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
})
|
||||
await self.send_command({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": self.sample_rate,
|
||||
"channels": 1,
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
},
|
||||
})
|
||||
|
||||
async def send_command(self, cmd: dict) -> None:
|
||||
"""Send JSON command to server."""
|
||||
if self.ws:
|
||||
await self.ws.send(json.dumps(cmd))
|
||||
print(f"→ Command: {cmd.get('command', 'unknown')}")
|
||||
print(f"→ Command: {cmd.get('type', 'unknown')}")
|
||||
|
||||
async def send_chat(self, text: str) -> None:
|
||||
"""Send chat message (text input)."""
|
||||
@@ -136,7 +165,7 @@ class MicrophoneClient:
|
||||
self.first_audio_received = False
|
||||
|
||||
await self.send_command({
|
||||
"command": "chat",
|
||||
"type": "input.text",
|
||||
"text": text
|
||||
})
|
||||
print(f"→ Chat: {text}")
|
||||
@@ -144,13 +173,14 @@ class MicrophoneClient:
|
||||
async def send_interrupt(self) -> None:
|
||||
"""Send interrupt command."""
|
||||
await self.send_command({
|
||||
"command": "interrupt"
|
||||
"type": "response.cancel",
|
||||
"graceful": False,
|
||||
})
|
||||
|
||||
async def send_hangup(self, reason: str = "User quit") -> None:
|
||||
"""Send hangup command."""
|
||||
await self.send_command({
|
||||
"command": "hangup",
|
||||
"type": "session.stop",
|
||||
"reason": reason
|
||||
})
|
||||
|
||||
@@ -295,43 +325,48 @@ class MicrophoneClient:
|
||||
|
||||
async def _handle_event(self, event: dict) -> None:
|
||||
"""Handle incoming event."""
|
||||
event_type = event.get("event", "unknown")
|
||||
event_type = event.get("type", event.get("event", "unknown"))
|
||||
ids = self._event_ids_suffix(event)
|
||||
if self.track_debug:
|
||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||
|
||||
if event_type == "answer":
|
||||
print("← Session ready!")
|
||||
elif event_type == "speaking":
|
||||
print("← User speech detected")
|
||||
elif event_type == "silence":
|
||||
print("← User silence detected")
|
||||
elif event_type == "transcript":
|
||||
if event_type in {"hello.ack", "session.started"}:
|
||||
print(f"← Session ready!{ids}")
|
||||
elif event_type == "config.resolved":
|
||||
print(f"← Config resolved: {event.get('config', {}).get('output', {})}{ids}")
|
||||
elif event_type == "input.speech_started":
|
||||
print(f"← User speech detected{ids}")
|
||||
elif event_type == "input.speech_stopped":
|
||||
print(f"← User silence detected{ids}")
|
||||
elif event_type in {"transcript", "transcript.delta", "transcript.final"}:
|
||||
# Display user speech transcription
|
||||
text = event.get("text", "")
|
||||
is_final = event.get("isFinal", False)
|
||||
is_final = event_type == "transcript.final" or bool(event.get("isFinal"))
|
||||
if is_final:
|
||||
# Clear the interim line and print final
|
||||
print(" " * 80, end="\r") # Clear previous interim text
|
||||
print(f"→ You: {text}")
|
||||
print(f"→ You: {text}{ids}")
|
||||
else:
|
||||
# Interim result - show with indicator (overwrite same line)
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
print(f" [listening] {display_text}".ljust(80), end="\r")
|
||||
elif event_type == "ttfb":
|
||||
elif event_type in {"ttfb", "metrics.ttfb"}:
|
||||
# Server-side TTFB event
|
||||
latency_ms = event.get("latencyMs", 0)
|
||||
print(f"← [TTFB] Server reported latency: {latency_ms}ms")
|
||||
elif event_type == "llmResponse":
|
||||
elif event_type in {"llmResponse", "assistant.response.delta", "assistant.response.final"}:
|
||||
# LLM text response
|
||||
text = event.get("text", "")
|
||||
is_final = event.get("isFinal", False)
|
||||
is_final = event_type == "assistant.response.final" or bool(event.get("isFinal"))
|
||||
if is_final:
|
||||
# Print final LLM response
|
||||
print(f"← AI: {text}")
|
||||
elif self.verbose:
|
||||
# Show streaming chunks only in verbose mode
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
print(f" [streaming] {display_text}")
|
||||
elif event_type == "trackStart":
|
||||
print("← Bot started speaking")
|
||||
print(f" [streaming] {display_text}{ids}")
|
||||
elif event_type in {"trackStart", "output.audio.start"}:
|
||||
print(f"← Bot started speaking{ids}")
|
||||
# IMPORTANT: Accept audio again after trackStart
|
||||
self._discard_audio = False
|
||||
self._audio_sequence += 1
|
||||
@@ -342,13 +377,13 @@ class MicrophoneClient:
|
||||
# Clear any old audio in buffer
|
||||
with self.audio_output_lock:
|
||||
self.audio_output_buffer = b""
|
||||
elif event_type == "trackEnd":
|
||||
print("← Bot finished speaking")
|
||||
elif event_type in {"trackEnd", "output.audio.end"}:
|
||||
print(f"← Bot finished speaking{ids}")
|
||||
# Reset TTFB tracking after response completes
|
||||
self.request_start_time = None
|
||||
self.first_audio_received = False
|
||||
elif event_type == "interrupt":
|
||||
print("← Bot interrupted!")
|
||||
elif event_type in {"interrupt", "response.interrupted"}:
|
||||
print(f"← Bot interrupted!{ids}")
|
||||
# IMPORTANT: Discard all audio until next trackStart
|
||||
self._discard_audio = True
|
||||
# Clear audio buffer immediately
|
||||
@@ -357,12 +392,12 @@ class MicrophoneClient:
|
||||
self.audio_output_buffer = b""
|
||||
print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
|
||||
elif event_type == "error":
|
||||
print(f"← Error: {event.get('error')}")
|
||||
elif event_type == "hangup":
|
||||
print(f"← Hangup: {event.get('reason')}")
|
||||
print(f"← Error: {event.get('error')}{ids}")
|
||||
elif event_type in {"hangup", "session.stopped"}:
|
||||
print(f"← Hangup: {event.get('reason')}{ids}")
|
||||
self.running = False
|
||||
else:
|
||||
print(f"← Event: {event_type}")
|
||||
print(f"← Event: {event_type}{ids}")
|
||||
|
||||
async def interactive_mode(self) -> None:
|
||||
"""Run interactive mode for text chat."""
|
||||
@@ -573,6 +608,26 @@ async def main():
|
||||
action="store_true",
|
||||
help="Show streaming LLM response chunks"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--app-id",
|
||||
default="assistant_demo",
|
||||
help="Stable app/assistant identifier for server-side config lookup"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
default="mic_client",
|
||||
help="Client channel name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-version-id",
|
||||
default="local-dev",
|
||||
help="Optional config version identifier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--track-debug",
|
||||
action="store_true",
|
||||
help="Print event trackId for protocol debugging"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -583,8 +638,12 @@ async def main():
|
||||
client = MicrophoneClient(
|
||||
url=args.url,
|
||||
sample_rate=args.sample_rate,
|
||||
app_id=args.app_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
input_device=args.input_device,
|
||||
output_device=args.output_device
|
||||
output_device=args.output_device,
|
||||
track_debug=args.track_debug,
|
||||
)
|
||||
client.verbose = args.verbose
|
||||
|
||||
|
||||
@@ -52,9 +52,21 @@ if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
|
||||
class SimpleVoiceClient:
|
||||
"""Simple voice client with reliable audio playback."""
|
||||
|
||||
def __init__(self, url: str, sample_rate: int = 16000):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
sample_rate: int = 16000,
|
||||
app_id: str = "assistant_demo",
|
||||
channel: str = "simple_client",
|
||||
config_version_id: str = "local-dev",
|
||||
track_debug: bool = False,
|
||||
):
|
||||
self.url = url
|
||||
self.sample_rate = sample_rate
|
||||
self.app_id = app_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.track_debug = track_debug
|
||||
self.ws = None
|
||||
self.running = False
|
||||
|
||||
@@ -76,6 +88,17 @@ class SimpleVoiceClient:
|
||||
# Interrupt handling - discard audio until next trackStart
|
||||
self._discard_audio = False
|
||||
|
||||
@staticmethod
|
||||
def _event_ids_suffix(event: dict) -> str:
|
||||
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
||||
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
|
||||
parts = []
|
||||
for key in keys:
|
||||
value = data.get(key, event.get(key))
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
@@ -83,12 +106,25 @@ class SimpleVoiceClient:
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# Send invite
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.ws.send(json.dumps({
|
||||
"command": "invite",
|
||||
"option": {"codec": "pcm", "sampleRate": self.sample_rate}
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
}))
|
||||
print("-> invite")
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": self.sample_rate,
|
||||
"channels": 1,
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
},
|
||||
}))
|
||||
print("-> hello/session.start")
|
||||
|
||||
async def send_chat(self, text: str):
|
||||
"""Send chat message."""
|
||||
@@ -96,8 +132,8 @@ class SimpleVoiceClient:
|
||||
self.request_start_time = time.time()
|
||||
self.first_audio_received = False
|
||||
|
||||
await self.ws.send(json.dumps({"command": "chat", "text": text}))
|
||||
print(f"-> chat: {text}")
|
||||
await self.ws.send(json.dumps({"type": "input.text", "text": text}))
|
||||
print(f"-> input.text: {text}")
|
||||
|
||||
def play_audio(self, audio_data: bytes):
|
||||
"""Play audio data immediately."""
|
||||
@@ -152,34 +188,39 @@ class SimpleVoiceClient:
|
||||
else:
|
||||
# JSON event
|
||||
event = json.loads(msg)
|
||||
etype = event.get("event", "?")
|
||||
etype = event.get("type", event.get("event", "?"))
|
||||
ids = self._event_ids_suffix(event)
|
||||
if self.track_debug:
|
||||
print(f"[track-debug] event={etype} trackId={event.get('trackId')}{ids}")
|
||||
|
||||
if etype == "transcript":
|
||||
if etype in {"transcript", "transcript.delta", "transcript.final"}:
|
||||
# User speech transcription
|
||||
text = event.get("text", "")
|
||||
is_final = event.get("isFinal", False)
|
||||
is_final = etype == "transcript.final" or bool(event.get("isFinal"))
|
||||
if is_final:
|
||||
print(f"<- You said: {text}")
|
||||
print(f"<- You said: {text}{ids}")
|
||||
else:
|
||||
print(f"<- [listening] {text}", end="\r")
|
||||
elif etype == "ttfb":
|
||||
elif etype in {"ttfb", "metrics.ttfb"}:
|
||||
# Server-side TTFB event
|
||||
latency_ms = event.get("latencyMs", 0)
|
||||
print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
|
||||
elif etype == "trackStart":
|
||||
elif etype in {"trackStart", "output.audio.start"}:
|
||||
# New track starting - accept audio again
|
||||
self._discard_audio = False
|
||||
print(f"<- {etype}")
|
||||
elif etype == "interrupt":
|
||||
print(f"<- {etype}{ids}")
|
||||
elif etype in {"interrupt", "response.interrupted"}:
|
||||
# Interrupt - discard audio until next trackStart
|
||||
self._discard_audio = True
|
||||
print(f"<- {etype} (discarding audio until new track)")
|
||||
elif etype == "hangup":
|
||||
print(f"<- {etype}")
|
||||
print(f"<- {etype}{ids} (discarding audio until new track)")
|
||||
elif etype in {"hangup", "session.stopped"}:
|
||||
print(f"<- {etype}{ids}")
|
||||
self.running = False
|
||||
break
|
||||
elif etype == "config.resolved":
|
||||
print(f"<- config.resolved {event.get('config', {}).get('output', {})}{ids}")
|
||||
else:
|
||||
print(f"<- {etype}")
|
||||
print(f"<- {etype}{ids}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
@@ -270,6 +311,10 @@ async def main():
|
||||
parser.add_argument("--text", help="Send text and play response")
|
||||
parser.add_argument("--list-devices", action="store_true")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000)
|
||||
parser.add_argument("--app-id", default="assistant_demo")
|
||||
parser.add_argument("--channel", default="simple_client")
|
||||
parser.add_argument("--config-version-id", default="local-dev")
|
||||
parser.add_argument("--track-debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -277,7 +322,14 @@ async def main():
|
||||
list_audio_devices()
|
||||
return
|
||||
|
||||
client = SimpleVoiceClient(args.url, args.sample_rate)
|
||||
client = SimpleVoiceClient(
|
||||
args.url,
|
||||
args.sample_rate,
|
||||
app_id=args.app_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
track_debug=args.track_debug,
|
||||
)
|
||||
await client.run(args.text)
|
||||
|
||||
|
||||
|
||||
@@ -36,8 +36,18 @@ def generate_sine_wave(duration_ms=1000):
|
||||
return audio_data
|
||||
|
||||
|
||||
async def receive_loop(ws, ready_event: asyncio.Event):
|
||||
async def receive_loop(ws, ready_event: asyncio.Event, track_debug: bool = False):
|
||||
"""Listen for incoming messages from the server."""
|
||||
def event_ids_suffix(data):
|
||||
payload = data.get("data") if isinstance(data.get("data"), dict) else {}
|
||||
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
|
||||
parts = []
|
||||
for key in keys:
|
||||
value = payload.get(key, data.get(key))
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
print("👂 Listening for server responses...")
|
||||
async for msg in ws:
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
@@ -46,7 +56,10 @@ async def receive_loop(ws, ready_event: asyncio.Event):
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
event_type = data.get('type', 'Unknown')
|
||||
print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
|
||||
ids = event_ids_suffix(data)
|
||||
print(f"[{timestamp}] 📨 Event: {event_type}{ids} | {msg.data[:150]}...")
|
||||
if track_debug:
|
||||
print(f"[{timestamp}] [track-debug] event={event_type} trackId={data.get('trackId')}{ids}")
|
||||
if event_type == "session.started":
|
||||
ready_event.set()
|
||||
except json.JSONDecodeError:
|
||||
@@ -113,7 +126,7 @@ async def send_sine_loop(ws):
|
||||
print("\n✅ Finished streaming test audio.")
|
||||
|
||||
|
||||
async def run_client(url, file_path=None, use_sine=False):
|
||||
async def run_client(url, file_path=None, use_sine=False, track_debug: bool = False):
|
||||
"""Run the WebSocket test client."""
|
||||
session = aiohttp.ClientSession()
|
||||
try:
|
||||
@@ -121,7 +134,7 @@ async def run_client(url, file_path=None, use_sine=False):
|
||||
async with session.ws_connect(url) as ws:
|
||||
print("✅ Connected!")
|
||||
session_ready = asyncio.Event()
|
||||
recv_task = asyncio.create_task(receive_loop(ws, session_ready))
|
||||
recv_task = asyncio.create_task(receive_loop(ws, session_ready, track_debug=track_debug))
|
||||
|
||||
# Send v1 hello + session.start handshake
|
||||
await ws.send_json({"type": "hello", "version": "v1"})
|
||||
@@ -131,7 +144,12 @@ async def run_client(url, file_path=None, use_sine=False):
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": SAMPLE_RATE,
|
||||
"channels": 1
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"appId": "assistant_demo",
|
||||
"channel": "test_websocket",
|
||||
"configVersionId": "local-dev",
|
||||
},
|
||||
})
|
||||
print("📤 Sent v1 hello/session.start")
|
||||
await asyncio.wait_for(session_ready.wait(), timeout=8)
|
||||
@@ -168,9 +186,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
|
||||
parser.add_argument("--file", help="Path to PCM/WAV file to stream")
|
||||
parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
|
||||
parser.add_argument("--track-debug", action="store_true", help="Print event trackId for protocol debugging")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
asyncio.run(run_client(args.url, args.file, args.sine))
|
||||
asyncio.run(run_client(args.url, args.file, args.sine, args.track_debug))
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Client stopped.")
|
||||
|
||||
@@ -57,10 +57,15 @@ class WavFileClient:
|
||||
url: str,
|
||||
input_file: str,
|
||||
output_file: str,
|
||||
app_id: str = "assistant_demo",
|
||||
channel: str = "wav_client",
|
||||
config_version_id: str = "local-dev",
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_ms: int = 20,
|
||||
wait_time: float = 15.0,
|
||||
verbose: bool = False
|
||||
verbose: bool = False,
|
||||
track_debug: bool = False,
|
||||
tail_silence_ms: int = 800,
|
||||
):
|
||||
"""
|
||||
Initialize WAV file client.
|
||||
@@ -77,11 +82,17 @@ class WavFileClient:
|
||||
self.url = url
|
||||
self.input_file = Path(input_file)
|
||||
self.output_file = Path(output_file)
|
||||
self.app_id = app_id
|
||||
self.channel = channel
|
||||
self.config_version_id = config_version_id
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_duration_ms = chunk_duration_ms
|
||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||
self.wait_time = wait_time
|
||||
self.verbose = verbose
|
||||
self.track_debug = track_debug
|
||||
self.tail_silence_ms = max(0, int(tail_silence_ms))
|
||||
self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms
|
||||
|
||||
# WebSocket connection
|
||||
self.ws = None
|
||||
@@ -105,6 +116,7 @@ class WavFileClient:
|
||||
self.track_started = False
|
||||
self.track_ended = False
|
||||
self.send_completed = False
|
||||
self.session_ready = False
|
||||
|
||||
# Events log
|
||||
self.events_log = []
|
||||
@@ -125,6 +137,17 @@ class WavFileClient:
|
||||
safe_message = message.encode('ascii', errors='replace').decode('ascii')
|
||||
print(f"{direction} {safe_message}")
|
||||
|
||||
@staticmethod
|
||||
def _event_ids_suffix(event: dict) -> str:
|
||||
data = event.get("data") if isinstance(event.get("data"), dict) else {}
|
||||
keys = ("turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id")
|
||||
parts = []
|
||||
for key in keys:
|
||||
value = data.get(key, event.get(key))
|
||||
if value:
|
||||
parts.append(f"{key}={value}")
|
||||
return f" [{' '.join(parts)}]" if parts else ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket server."""
|
||||
self.log_event("→", f"Connecting to {self.url}...")
|
||||
@@ -132,25 +155,35 @@ class WavFileClient:
|
||||
self.running = True
|
||||
self.log_event("←", "Connected!")
|
||||
|
||||
# Send invite command
|
||||
# WS v1 handshake: hello -> session.start
|
||||
await self.send_command({
|
||||
"command": "invite",
|
||||
"option": {
|
||||
"codec": "pcm",
|
||||
"sampleRate": self.sample_rate
|
||||
}
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
})
|
||||
await self.send_command({
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": self.sample_rate,
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"appId": self.app_id,
|
||||
"channel": self.channel,
|
||||
"configVersionId": self.config_version_id,
|
||||
},
|
||||
})
|
||||
|
||||
async def send_command(self, cmd: dict) -> None:
|
||||
"""Send JSON command to server."""
|
||||
if self.ws:
|
||||
await self.ws.send(json.dumps(cmd))
|
||||
self.log_event("→", f"Command: {cmd.get('command', 'unknown')}")
|
||||
self.log_event("→", f"Command: {cmd.get('type', 'unknown')}")
|
||||
|
||||
async def send_hangup(self, reason: str = "Session complete") -> None:
|
||||
"""Send hangup command."""
|
||||
await self.send_command({
|
||||
"command": "hangup",
|
||||
"type": "session.stop",
|
||||
"reason": reason
|
||||
})
|
||||
|
||||
@@ -210,6 +243,10 @@ class WavFileClient:
|
||||
end_sample = min(sent_samples + chunk_size, total_samples)
|
||||
chunk = audio_data[sent_samples:end_sample]
|
||||
chunk_bytes = chunk.tobytes()
|
||||
if len(chunk_bytes) % self.frame_bytes != 0:
|
||||
# v1 audio framing requires 640-byte (20ms) PCM units.
|
||||
pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes)
|
||||
chunk_bytes += b"\x00" * pad
|
||||
|
||||
# Send to server
|
||||
if self.ws:
|
||||
@@ -227,6 +264,16 @@ class WavFileClient:
|
||||
# Server expects audio at real-time pace for VAD/ASR to work properly
|
||||
await asyncio.sleep(self.chunk_duration_ms / 1000)
|
||||
|
||||
# Add a short silence tail to help VAD/EOU close the final utterance.
|
||||
if self.tail_silence_ms > 0 and self.ws:
|
||||
tail_frames = max(1, self.tail_silence_ms // 20)
|
||||
silence = b"\x00" * self.frame_bytes
|
||||
for _ in range(tail_frames):
|
||||
await self.ws.send(silence)
|
||||
self.bytes_sent += len(silence)
|
||||
await asyncio.sleep(0.02)
|
||||
self.log_event("→", f"Sent trailing silence: {self.tail_silence_ms}ms")
|
||||
|
||||
self.send_completed = True
|
||||
elapsed = time.time() - self.send_start_time
|
||||
self.log_event("→", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)")
|
||||
@@ -277,54 +324,59 @@ class WavFileClient:
|
||||
|
||||
async def _handle_event(self, event: dict) -> None:
|
||||
"""Handle incoming event."""
|
||||
event_type = event.get("event", "unknown")
|
||||
event_type = event.get("type", "unknown")
|
||||
ids = self._event_ids_suffix(event)
|
||||
if self.track_debug:
|
||||
print(f"[track-debug] event={event_type} trackId={event.get('trackId')}{ids}")
|
||||
|
||||
if event_type == "answer":
|
||||
self.log_event("←", "Session ready!")
|
||||
elif event_type == "speaking":
|
||||
self.log_event("←", "Speech detected")
|
||||
elif event_type == "silence":
|
||||
self.log_event("←", "Silence detected")
|
||||
elif event_type == "transcript":
|
||||
# ASR transcript (interim = asrDelta-style, final = asrFinal-style)
|
||||
if event_type == "hello.ack":
|
||||
self.log_event("←", f"Handshake acknowledged{ids}")
|
||||
elif event_type == "session.started":
|
||||
self.session_ready = True
|
||||
self.log_event("←", f"Session ready!{ids}")
|
||||
elif event_type == "config.resolved":
|
||||
config = event.get("config", {})
|
||||
self.log_event("←", f"Config resolved (output={config.get('output', {})}){ids}")
|
||||
elif event_type == "input.speech_started":
|
||||
self.log_event("←", f"Speech detected{ids}")
|
||||
elif event_type == "input.speech_stopped":
|
||||
self.log_event("←", f"Silence detected{ids}")
|
||||
elif event_type == "transcript.delta":
|
||||
text = event.get("text", "")
|
||||
is_final = event.get("isFinal", False)
|
||||
if is_final:
|
||||
# Clear interim line and print final
|
||||
print(" " * 80, end="\r")
|
||||
self.log_event("←", f"→ You: {text}")
|
||||
else:
|
||||
# Interim result - show with indicator (overwrite same line, as in mic_client)
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
print(f" [listening] {display_text}".ljust(80), end="\r")
|
||||
elif event_type == "ttfb":
|
||||
elif event_type == "transcript.final":
|
||||
text = event.get("text", "")
|
||||
print(" " * 80, end="\r")
|
||||
self.log_event("←", f"→ You: {text}{ids}")
|
||||
elif event_type == "metrics.ttfb":
|
||||
latency_ms = event.get("latencyMs", 0)
|
||||
self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms")
|
||||
elif event_type == "llmResponse":
|
||||
elif event_type == "assistant.response.delta":
|
||||
text = event.get("text", "")
|
||||
is_final = event.get("isFinal", False)
|
||||
if is_final:
|
||||
self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
|
||||
elif self.verbose:
|
||||
# Show streaming chunks only in verbose mode
|
||||
self.log_event("←", f"LLM: {text}")
|
||||
elif event_type == "trackStart":
|
||||
if self.verbose and text:
|
||||
self.log_event("←", f"LLM: {text}{ids}")
|
||||
elif event_type == "assistant.response.final":
|
||||
text = event.get("text", "")
|
||||
if text:
|
||||
self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}")
|
||||
elif event_type == "output.audio.start":
|
||||
self.track_started = True
|
||||
self.response_start_time = time.time()
|
||||
self.waiting_for_first_audio = True
|
||||
self.log_event("←", "Bot started speaking")
|
||||
elif event_type == "trackEnd":
|
||||
self.log_event("←", f"Bot started speaking{ids}")
|
||||
elif event_type == "output.audio.end":
|
||||
self.track_ended = True
|
||||
self.log_event("←", "Bot finished speaking")
|
||||
elif event_type == "interrupt":
|
||||
self.log_event("←", "Bot interrupted!")
|
||||
self.log_event("←", f"Bot finished speaking{ids}")
|
||||
elif event_type == "response.interrupted":
|
||||
self.log_event("←", f"Bot interrupted!{ids}")
|
||||
elif event_type == "error":
|
||||
self.log_event("!", f"Error: {event.get('error')}")
|
||||
elif event_type == "hangup":
|
||||
self.log_event("←", f"Hangup: {event.get('reason')}")
|
||||
self.log_event("!", f"Error: {event.get('message')}{ids}")
|
||||
elif event_type == "session.stopped":
|
||||
self.log_event("←", f"Session stopped: {event.get('reason')}{ids}")
|
||||
self.running = False
|
||||
else:
|
||||
self.log_event("←", f"Event: {event_type}")
|
||||
self.log_event("←", f"Event: {event_type}{ids}")
|
||||
|
||||
def save_output_wav(self) -> None:
|
||||
"""Save received audio to output WAV file."""
|
||||
@@ -359,12 +411,16 @@ class WavFileClient:
|
||||
# Connect to server
|
||||
await self.connect()
|
||||
|
||||
# Wait for answer
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Start receiver task
|
||||
receiver_task = asyncio.create_task(self.receiver())
|
||||
|
||||
# Wait for session.started before streaming audio
|
||||
ready_start = time.time()
|
||||
while self.running and not self.session_ready:
|
||||
if time.time() - ready_start > 8.0:
|
||||
raise TimeoutError("Timeout waiting for session.started")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Send audio
|
||||
await self.audio_sender(audio_data)
|
||||
|
||||
@@ -464,6 +520,21 @@ async def main():
|
||||
default=16000,
|
||||
help="Target sample rate for audio (default: 16000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--app-id",
|
||||
default="assistant_demo",
|
||||
help="Stable app/assistant identifier for server-side config lookup"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
default="wav_client",
|
||||
help="Client channel name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-version-id",
|
||||
default="local-dev",
|
||||
help="Optional config version identifier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-duration",
|
||||
type=int,
|
||||
@@ -481,6 +552,17 @@ async def main():
|
||||
action="store_true",
|
||||
help="Enable verbose output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--track-debug",
|
||||
action="store_true",
|
||||
help="Print event trackId for protocol debugging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tail-silence-ms",
|
||||
type=int,
|
||||
default=800,
|
||||
help="Trailing silence to send after WAV playback for EOU detection (default: 800)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -488,10 +570,15 @@ async def main():
|
||||
url=args.url,
|
||||
input_file=args.input,
|
||||
output_file=args.output,
|
||||
app_id=args.app_id,
|
||||
channel=args.channel,
|
||||
config_version_id=args.config_version_id,
|
||||
sample_rate=args.sample_rate,
|
||||
chunk_duration_ms=args.chunk_duration,
|
||||
wait_time=args.wait_time,
|
||||
verbose=args.verbose
|
||||
verbose=args.verbose,
|
||||
track_debug=args.track_debug,
|
||||
tail_silence_ms=args.tail_silence_ms,
|
||||
)
|
||||
|
||||
await client.run()
|
||||
|
||||
@@ -401,6 +401,9 @@
|
||||
|
||||
const targetSampleRate = 16000;
|
||||
const playbackStopRampSec = 0.008;
|
||||
const appId = "assistant_demo";
|
||||
const channel = "web_client";
|
||||
const configVersionId = "local-dev";
|
||||
|
||||
function logLine(type, text, data) {
|
||||
const time = new Date().toLocaleTimeString();
|
||||
@@ -604,15 +607,35 @@
|
||||
logLine("sys", `→ ${cmd.type}`, cmd);
|
||||
}
|
||||
|
||||
function eventIdsSuffix(event) {
|
||||
const data = event && typeof event.data === "object" && event.data ? event.data : {};
|
||||
const keys = ["turn_id", "utterance_id", "response_id", "tool_call_id", "tts_id"];
|
||||
const parts = [];
|
||||
for (const key of keys) {
|
||||
const value = data[key] || event[key];
|
||||
if (value) parts.push(`${key}=${value}`);
|
||||
}
|
||||
return parts.length ? ` [${parts.join(" ")}]` : "";
|
||||
}
|
||||
|
||||
function handleEvent(event) {
|
||||
const type = event.type || "unknown";
|
||||
logLine("event", type, event);
|
||||
const ids = eventIdsSuffix(event);
|
||||
logLine("event", `${type}${ids}`, event);
|
||||
if (type === "hello.ack") {
|
||||
sendCommand({
|
||||
type: "session.start",
|
||||
audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
|
||||
metadata: {
|
||||
appId,
|
||||
channel,
|
||||
configVersionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
if (type === "config.resolved") {
|
||||
logLine("sys", "config.resolved", event.config || {});
|
||||
}
|
||||
if (type === "transcript.final") {
|
||||
if (event.text) {
|
||||
setInterim("You", "");
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""WS v1 protocol message models and helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
def now_ms() -> int:
|
||||
@@ -11,37 +12,66 @@ def now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
class _StrictModel(BaseModel):
|
||||
"""Protocol models reject unknown fields to enforce WS v1 schema."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
# Client -> Server messages
|
||||
class HelloMessage(BaseModel):
|
||||
class HelloAuth(_StrictModel):
|
||||
apiKey: Optional[str] = None
|
||||
jwt: Optional[str] = None
|
||||
|
||||
|
||||
class HelloMessage(_StrictModel):
|
||||
type: Literal["hello"]
|
||||
version: str = Field(..., description="Protocol version, currently v1")
|
||||
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
|
||||
version: Literal["v1"] = Field(..., description="Protocol version, fixed to v1")
|
||||
auth: Optional[HelloAuth] = Field(default=None, description="Auth payload")
|
||||
|
||||
|
||||
class SessionStartMessage(BaseModel):
|
||||
class SessionStartAudio(_StrictModel):
|
||||
encoding: Literal["pcm_s16le"] = "pcm_s16le"
|
||||
sample_rate_hz: Literal[16000] = 16000
|
||||
channels: Literal[1] = 1
|
||||
|
||||
|
||||
class SessionStartMessage(_StrictModel):
|
||||
type: Literal["session.start"]
|
||||
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
|
||||
audio: Optional[SessionStartAudio] = Field(default=None, description="Optional audio format metadata")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
||||
|
||||
|
||||
class SessionStopMessage(BaseModel):
|
||||
class SessionStopMessage(_StrictModel):
|
||||
type: Literal["session.stop"]
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class InputTextMessage(BaseModel):
|
||||
class InputTextMessage(_StrictModel):
|
||||
type: Literal["input.text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseCancelMessage(BaseModel):
|
||||
class ResponseCancelMessage(_StrictModel):
|
||||
type: Literal["response.cancel"]
|
||||
graceful: bool = False
|
||||
|
||||
|
||||
class ToolCallResultsMessage(BaseModel):
|
||||
class ToolCallResultStatus(_StrictModel):
|
||||
code: int
|
||||
message: str
|
||||
|
||||
|
||||
class ToolCallResult(_StrictModel):
|
||||
tool_call_id: str
|
||||
name: str
|
||||
output: Any = None
|
||||
status: ToolCallResultStatus
|
||||
|
||||
|
||||
class ToolCallResultsMessage(_StrictModel):
|
||||
type: Literal["tool_call.results"]
|
||||
results: list[Dict[str, Any]] = Field(default_factory=list)
|
||||
results: list[ToolCallResult] = Field(default_factory=list)
|
||||
|
||||
|
||||
CLIENT_MESSAGE_TYPES = {
|
||||
@@ -62,7 +92,15 @@ def parse_client_message(data: Dict[str, Any]) -> BaseModel:
|
||||
msg_class = CLIENT_MESSAGE_TYPES.get(msg_type)
|
||||
if not msg_class:
|
||||
raise ValueError(f"Unknown client message type: {msg_type}")
|
||||
try:
|
||||
return msg_class(**data)
|
||||
except ValidationError as exc:
|
||||
details = []
|
||||
for err in exc.errors():
|
||||
loc = ".".join(str(part) for part in err.get("loc", ()))
|
||||
msg = err.get("msg", "invalid field")
|
||||
details.append(f"{loc}: {msg}" if loc else msg)
|
||||
raise ValueError("; ".join(details)) from exc
|
||||
|
||||
|
||||
# Server -> Client event helpers
|
||||
|
||||
@@ -17,6 +17,7 @@ pydantic>=2.5.3
|
||||
pydantic-settings>=2.1.0
|
||||
python-dotenv>=1.0.0
|
||||
toml>=0.10.2
|
||||
pyyaml>=6.0.1
|
||||
|
||||
# Logging
|
||||
loguru>=0.7.2
|
||||
|
||||
@@ -7,10 +7,10 @@ for real-time voice conversation.
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
@@ -37,20 +37,21 @@ class OpenAILLMService(BaseLLMService):
|
||||
base_url: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
knowledge_config: Optional[Dict[str, Any]] = None,
|
||||
knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize OpenAI LLM service.
|
||||
|
||||
Args:
|
||||
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
|
||||
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
|
||||
api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars)
|
||||
base_url: Custom API base URL (for Azure or compatible APIs)
|
||||
system_prompt: Default system prompt for conversations
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_URL")
|
||||
self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
|
||||
self.system_prompt = system_prompt or (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational. "
|
||||
@@ -60,6 +61,11 @@ class OpenAILLMService(BaseLLMService):
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||
if knowledge_searcher is None:
|
||||
adapter = build_backend_adapter_from_settings()
|
||||
self._knowledge_searcher = adapter.search_knowledge_context
|
||||
else:
|
||||
self._knowledge_searcher = knowledge_searcher
|
||||
self._tool_schemas: List[Dict[str, Any]] = []
|
||||
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
@@ -224,7 +230,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
|
||||
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
|
||||
|
||||
results = await search_knowledge_context(
|
||||
results = await self._knowledge_searcher(
|
||||
kb_id=kb_id,
|
||||
query=latest_user,
|
||||
n_results=n_results,
|
||||
|
||||
@@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
@@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
@@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
Args:
|
||||
api_key: Provider API key
|
||||
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||
model: ASR model name or alias
|
||||
sample_rate: Audio sample rate (16000 recommended)
|
||||
language: Language code (auto for automatic detection)
|
||||
@@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||
|
||||
self.api_key = api_key
|
||||
self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
|
||||
self.model = self.MODELS.get(model.lower(), model)
|
||||
self.interim_interval_ms = interim_interval_ms
|
||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||
@@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the service."""
|
||||
if not self.api_key:
|
||||
raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
@@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
)
|
||||
form_data.add_field('model', self.model)
|
||||
|
||||
async with self._session.post(self.API_URL, data=form_data) as response:
|
||||
async with self._session.post(self.api_url, data=form_data) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
text = result.get("text", "").strip()
|
||||
|
||||
@@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "anna",
|
||||
model: str = "FunAudioLLM/CosyVoice2-0.5B",
|
||||
sample_rate: int = 16000,
|
||||
@@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
Initialize OpenAI-compatible TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
|
||||
api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars)
|
||||
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
|
||||
model: Model name
|
||||
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
|
||||
@@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
|
||||
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
|
||||
|
||||
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.model = model
|
||||
self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
|
||||
self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
@@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
|
||||
raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
|
||||
252
engine/tests/test_agent_config.py
Normal file
252
engine/tests/test_agent_config.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("LLM_API_KEY", "test-openai-key")
|
||||
os.environ.setdefault("TTS_API_KEY", "test-tts-key")
|
||||
os.environ.setdefault("ASR_API_KEY", "test-asr-key")
|
||||
|
||||
from app.config import load_settings
|
||||
|
||||
|
||||
def _write_yaml(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str:
|
||||
return f"""
|
||||
agent:
|
||||
vad:
|
||||
type: silero
|
||||
model_path: data/vad/silero_vad.onnx
|
||||
threshold: 0.63
|
||||
min_speech_duration_ms: 100
|
||||
eou_threshold_ms: 800
|
||||
|
||||
llm:
|
||||
provider: openai_compatible
|
||||
model: {llm_model}
|
||||
temperature: 0.2
|
||||
api_key: {llm_key}
|
||||
api_url: https://example-llm.invalid/v1
|
||||
|
||||
tts:
|
||||
provider: openai_compatible
|
||||
api_key: test-tts-key
|
||||
api_url: https://example-tts.invalid/v1/audio/speech
|
||||
model: FunAudioLLM/CosyVoice2-0.5B
|
||||
voice: anna
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
provider: openai_compatible
|
||||
api_key: test-asr-key
|
||||
api_url: https://example-asr.invalid/v1/audio/transcriptions
|
||||
model: FunAudioLLM/SenseVoiceSmall
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
pre_speech_ms: 240
|
||||
final_tail_ms: 120
|
||||
|
||||
duplex:
|
||||
enabled: true
|
||||
system_prompt: You are a strict test assistant.
|
||||
|
||||
barge_in:
|
||||
min_duration_ms: 200
|
||||
silence_tolerance_ms: 60
|
||||
""".strip()
|
||||
|
||||
|
||||
def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
config_dir = tmp_path / "config" / "agents"
|
||||
_write_yaml(
|
||||
config_dir / "support.yaml",
|
||||
_full_agent_yaml(llm_model="gpt-4.1-mini"),
|
||||
)
|
||||
|
||||
settings = load_settings(
|
||||
argv=["--agent-profile", "support"],
|
||||
)
|
||||
|
||||
assert settings.llm_model == "gpt-4.1-mini"
|
||||
assert settings.llm_temperature == 0.2
|
||||
assert settings.vad_threshold == 0.63
|
||||
assert settings.agent_config_source == "cli_profile"
|
||||
assert settings.agent_config_path == str((config_dir / "support.yaml").resolve())
|
||||
|
||||
|
||||
def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
env_file = tmp_path / "config" / "agents" / "env.yaml"
|
||||
cli_file = tmp_path / "config" / "agents" / "cli.yaml"
|
||||
|
||||
_write_yaml(env_file, _full_agent_yaml(llm_model="env-model"))
|
||||
_write_yaml(cli_file, _full_agent_yaml(llm_model="cli-model"))
|
||||
|
||||
monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_file))
|
||||
|
||||
settings = load_settings(argv=["--agent-config", str(cli_file)])
|
||||
|
||||
assert settings.llm_model == "cli-model"
|
||||
assert settings.agent_config_source == "cli_path"
|
||||
assert settings.agent_config_path == str(cli_file.resolve())
|
||||
|
||||
|
||||
def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
default_file = tmp_path / "config" / "agents" / "default.yaml"
|
||||
_write_yaml(default_file, _full_agent_yaml(llm_model="from-default"))
|
||||
|
||||
monkeypatch.delenv("AGENT_CONFIG_PATH", raising=False)
|
||||
monkeypatch.delenv("AGENT_PROFILE", raising=False)
|
||||
|
||||
settings = load_settings(argv=[])
|
||||
|
||||
assert settings.llm_model == "from-default"
|
||||
assert settings.agent_config_source == "default"
|
||||
assert settings.agent_config_path == str(default_file.resolve())
|
||||
|
||||
|
||||
def test_missing_required_agent_settings_fail(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "missing-required.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
agent:
|
||||
llm:
|
||||
model: gpt-4o-mini
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_blank_required_provider_key_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "blank-key.yaml"
|
||||
_write_yaml(file_path, _full_agent_yaml(llm_key=""))
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_missing_tts_api_url_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "missing-tts-url.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
_full_agent_yaml().replace(
|
||||
" api_url: https://example-tts.invalid/v1/audio/speech\n",
|
||||
"",
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_missing_asr_api_url_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "missing-asr-url.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
_full_agent_yaml().replace(
|
||||
" api_url: https://example-asr.invalid/v1/audio/transcriptions\n",
|
||||
"",
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "bad-agent.yaml"
|
||||
_write_yaml(file_path, _full_agent_yaml() + "\n unknown_option: true")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown agent config keys"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "legacy-siliconflow.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
agent:
|
||||
siliconflow:
|
||||
api_key: x
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "bad-ref.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
_full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing environment variable"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_agent_yaml_tools_list_is_loaded(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "tools-agent.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
_full_agent_yaml()
|
||||
+ """
|
||||
|
||||
tools:
|
||||
- current_time
|
||||
- name: weather
|
||||
description: Get weather by city.
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
city:
|
||||
type: string
|
||||
required: [city]
|
||||
executor: server
|
||||
""",
|
||||
)
|
||||
|
||||
settings = load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
assert isinstance(settings.tools, list)
|
||||
assert settings.tools[0] == "current_time"
|
||||
assert settings.tools[1]["name"] == "weather"
|
||||
assert settings.tools[1]["executor"] == "server"
|
||||
|
||||
|
||||
def test_agent_yaml_tools_must_be_list(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "bad-tools-agent.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
_full_agent_yaml()
|
||||
+ """
|
||||
|
||||
tools:
|
||||
weather:
|
||||
executor: server
|
||||
""",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Agent config key 'tools' must be a list"):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
150
engine/tests/test_backend_adapters.py
Normal file
150
engine/tests/test_backend_adapters.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from app.backend_adapters import (
|
||||
HistoryDisabledBackendAdapter,
|
||||
HttpBackendAdapter,
|
||||
NullBackendAdapter,
|
||||
build_backend_adapter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_backend_adapter_without_url_returns_null_adapter():
|
||||
adapter = build_backend_adapter(
|
||||
backend_url=None,
|
||||
backend_mode="auto",
|
||||
history_enabled=True,
|
||||
timeout_sec=3,
|
||||
)
|
||||
assert isinstance(adapter, NullBackendAdapter)
|
||||
|
||||
assert await adapter.fetch_assistant_config("assistant_1") is None
|
||||
assert (
|
||||
await adapter.create_call_record(
|
||||
user_id=1,
|
||||
assistant_id="assistant_1",
|
||||
source="debug",
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await adapter.add_transcript(
|
||||
call_id="call_1",
|
||||
turn_index=0,
|
||||
speaker="human",
|
||||
content="hi",
|
||||
start_ms=0,
|
||||
end_ms=100,
|
||||
confidence=0.9,
|
||||
duration_ms=100,
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
await adapter.finalize_call_record(
|
||||
call_id="call_1",
|
||||
status="connected",
|
||||
duration_seconds=2,
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert await adapter.search_knowledge_context(kb_id="kb_1", query="hello", n_results=3) == []
|
||||
assert await adapter.fetch_tool_resource("tool_1") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_backend_adapter_create_call_record_posts_expected_payload(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status=200, payload=None):
|
||||
self.status = status
|
||||
self._payload = payload if payload is not None else {}
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def json(self):
|
||||
return self._payload
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status >= 400:
|
||||
raise RuntimeError("http_error")
|
||||
|
||||
class _FakeClientSession:
|
||||
def __init__(self, timeout=None):
|
||||
captured["timeout"] = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def post(self, url, json=None):
|
||||
captured["url"] = url
|
||||
captured["json"] = json
|
||||
return _FakeResponse(status=200, payload={"id": "call_123"})
|
||||
|
||||
monkeypatch.setattr("app.backend_adapters.aiohttp.ClientSession", _FakeClientSession)
|
||||
|
||||
adapter = build_backend_adapter(
|
||||
backend_url="http://localhost:8100",
|
||||
backend_mode="auto",
|
||||
history_enabled=True,
|
||||
timeout_sec=7,
|
||||
)
|
||||
assert isinstance(adapter, HttpBackendAdapter)
|
||||
|
||||
call_id = await adapter.create_call_record(
|
||||
user_id=99,
|
||||
assistant_id="assistant_9",
|
||||
source="debug",
|
||||
)
|
||||
|
||||
assert call_id == "call_123"
|
||||
assert captured["url"] == "http://localhost:8100/api/history"
|
||||
assert captured["json"] == {
|
||||
"user_id": 99,
|
||||
"assistant_id": "assistant_9",
|
||||
"source": "debug",
|
||||
"status": "connected",
|
||||
}
|
||||
assert isinstance(captured["timeout"], aiohttp.ClientTimeout)
|
||||
assert captured["timeout"].total == 7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backend_mode_disabled_forces_null_even_with_url():
|
||||
adapter = build_backend_adapter(
|
||||
backend_url="http://localhost:8100",
|
||||
backend_mode="disabled",
|
||||
history_enabled=True,
|
||||
timeout_sec=7,
|
||||
)
|
||||
assert isinstance(adapter, NullBackendAdapter)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_disabled_wraps_backend_adapter():
|
||||
adapter = build_backend_adapter(
|
||||
backend_url="http://localhost:8100",
|
||||
backend_mode="auto",
|
||||
history_enabled=False,
|
||||
timeout_sec=7,
|
||||
)
|
||||
assert isinstance(adapter, HistoryDisabledBackendAdapter)
|
||||
assert await adapter.create_call_record(user_id=1, assistant_id="a1", source="debug") is None
|
||||
assert await adapter.add_transcript(
|
||||
call_id="c1",
|
||||
turn_index=0,
|
||||
speaker="human",
|
||||
content="hi",
|
||||
start_ms=0,
|
||||
end_ms=10,
|
||||
duration_ms=10,
|
||||
) is False
|
||||
147
engine/tests/test_history_bridge.py
Normal file
147
engine/tests/test_history_bridge.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from core.history_bridge import SessionHistoryBridge
|
||||
|
||||
|
||||
class _FakeHistoryWriter:
|
||||
def __init__(self, *, add_delay_s: float = 0.0, add_result: bool = True):
|
||||
self.add_delay_s = add_delay_s
|
||||
self.add_result = add_result
|
||||
self.created_call_ids = []
|
||||
self.transcripts = []
|
||||
self.finalize_calls = 0
|
||||
self.finalize_statuses = []
|
||||
self.finalize_at = None
|
||||
self.last_transcript_at = None
|
||||
|
||||
async def create_call_record(self, *, user_id: int, assistant_id: str | None, source: str = "debug"):
|
||||
_ = (user_id, assistant_id, source)
|
||||
call_id = "call_test_1"
|
||||
self.created_call_ids.append(call_id)
|
||||
return call_id
|
||||
|
||||
async def add_transcript(
|
||||
self,
|
||||
*,
|
||||
call_id: str,
|
||||
turn_index: int,
|
||||
speaker: str,
|
||||
content: str,
|
||||
start_ms: int,
|
||||
end_ms: int,
|
||||
confidence: float | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> bool:
|
||||
_ = confidence
|
||||
if self.add_delay_s > 0:
|
||||
await asyncio.sleep(self.add_delay_s)
|
||||
self.transcripts.append(
|
||||
{
|
||||
"call_id": call_id,
|
||||
"turn_index": turn_index,
|
||||
"speaker": speaker,
|
||||
"content": content,
|
||||
"start_ms": start_ms,
|
||||
"end_ms": end_ms,
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
)
|
||||
self.last_transcript_at = time.monotonic()
|
||||
return self.add_result
|
||||
|
||||
async def finalize_call_record(self, *, call_id: str, status: str, duration_seconds: int) -> bool:
|
||||
_ = (call_id, duration_seconds)
|
||||
self.finalize_calls += 1
|
||||
self.finalize_statuses.append(status)
|
||||
self.finalize_at = time.monotonic()
|
||||
return True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_backend_does_not_block_enqueue():
|
||||
writer = _FakeHistoryWriter(add_delay_s=0.15, add_result=True)
|
||||
bridge = SessionHistoryBridge(
|
||||
history_writer=writer,
|
||||
enabled=True,
|
||||
queue_max_size=32,
|
||||
retry_max_attempts=0,
|
||||
retry_backoff_sec=0.01,
|
||||
finalize_drain_timeout_sec=1.0,
|
||||
)
|
||||
|
||||
try:
|
||||
call_id = await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
|
||||
assert call_id == "call_test_1"
|
||||
|
||||
t0 = time.perf_counter()
|
||||
queued = bridge.enqueue_turn(role="user", text="hello world")
|
||||
elapsed_s = time.perf_counter() - t0
|
||||
|
||||
assert queued is True
|
||||
assert elapsed_s < 0.02
|
||||
|
||||
await bridge.finalize(status="connected")
|
||||
assert len(writer.transcripts) == 1
|
||||
assert writer.finalize_calls == 1
|
||||
finally:
|
||||
await bridge.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failing_backend_retries_but_enqueue_remains_non_blocking():
|
||||
writer = _FakeHistoryWriter(add_delay_s=0.01, add_result=False)
|
||||
bridge = SessionHistoryBridge(
|
||||
history_writer=writer,
|
||||
enabled=True,
|
||||
queue_max_size=32,
|
||||
retry_max_attempts=2,
|
||||
retry_backoff_sec=0.01,
|
||||
finalize_drain_timeout_sec=0.5,
|
||||
)
|
||||
|
||||
try:
|
||||
await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
|
||||
t0 = time.perf_counter()
|
||||
assert bridge.enqueue_turn(role="assistant", text="retry me")
|
||||
elapsed_s = time.perf_counter() - t0
|
||||
assert elapsed_s < 0.02
|
||||
|
||||
await bridge.finalize(status="connected")
|
||||
|
||||
# Initial try + 2 retries
|
||||
assert len(writer.transcripts) == 3
|
||||
assert writer.finalize_calls == 1
|
||||
finally:
|
||||
await bridge.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalize_is_idempotent_and_waits_for_queue_drain():
|
||||
writer = _FakeHistoryWriter(add_delay_s=0.05, add_result=True)
|
||||
bridge = SessionHistoryBridge(
|
||||
history_writer=writer,
|
||||
enabled=True,
|
||||
queue_max_size=32,
|
||||
retry_max_attempts=0,
|
||||
retry_backoff_sec=0.01,
|
||||
finalize_drain_timeout_sec=1.0,
|
||||
)
|
||||
|
||||
try:
|
||||
await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug")
|
||||
assert bridge.enqueue_turn(role="user", text="first")
|
||||
|
||||
ok_1 = await bridge.finalize(status="connected")
|
||||
ok_2 = await bridge.finalize(status="connected")
|
||||
|
||||
assert ok_1 is True
|
||||
assert ok_2 is True
|
||||
assert len(writer.transcripts) == 1
|
||||
assert writer.finalize_calls == 1
|
||||
assert writer.last_transcript_at is not None
|
||||
assert writer.finalize_at is not None
|
||||
assert writer.finalize_at >= writer.last_transcript_at
|
||||
finally:
|
||||
await bridge.shutdown()
|
||||
@@ -92,16 +92,44 @@ def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tupl
|
||||
return pipeline, events
|
||||
|
||||
|
||||
def test_pipeline_uses_default_tools_from_settings(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.duplex_pipeline.settings.tools",
|
||||
[
|
||||
"current_time",
|
||||
{
|
||||
"name": "weather",
|
||||
"description": "Get weather by city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"executor": "server",
|
||||
},
|
||||
],
|
||||
)
|
||||
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
|
||||
|
||||
cfg = pipeline.resolved_runtime_config()
|
||||
assert cfg["tools"]["allowlist"] == ["current_time", "weather"]
|
||||
|
||||
schemas = pipeline._resolved_tool_schemas()
|
||||
names = [s.get("function", {}).get("name") for s in schemas if isinstance(s, dict)]
|
||||
assert "current_time" in names
|
||||
assert "weather" in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_message_parses_tool_call_results():
|
||||
msg = parse_client_message(
|
||||
{
|
||||
"type": "tool_call.results",
|
||||
"results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
|
||||
"results": [{"tool_call_id": "call_1", "name": "weather", "status": {"code": 200, "message": "ok"}}],
|
||||
}
|
||||
)
|
||||
assert isinstance(msg, ToolCallResultsMessage)
|
||||
assert msg.results[0]["tool_call_id"] == "call_1"
|
||||
assert msg.results[0].tool_call_id == "call_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -6,6 +6,7 @@ const getApiBaseUrl = (): string => {
|
||||
const configured = import.meta.env.VITE_API_BASE_URL || DEFAULT_API_BASE_URL;
|
||||
return trimTrailingSlash(configured);
|
||||
};
|
||||
export { getApiBaseUrl };
|
||||
|
||||
type RequestOptions = {
|
||||
method?: 'GET' | 'POST' | 'PUT' | 'DELETE';
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import { ASRModel, Assistant, CallLog, InteractionDetail, KnowledgeBase, KnowledgeDocument, LLMModel, Tool, Voice, Workflow, WorkflowEdge, WorkflowNode } from '../types';
|
||||
import { apiRequest } from './apiClient';
|
||||
import { apiRequest, getApiBaseUrl } from './apiClient';
|
||||
|
||||
type AnyRecord = Record<string, any>;
|
||||
const DEFAULT_LIST_LIMIT = 1000;
|
||||
|
||||
const withLimit = (path: string, limit: number = DEFAULT_LIST_LIMIT): string =>
|
||||
`${path}${path.includes('?') ? '&' : '?'}limit=${limit}`;
|
||||
|
||||
const readField = <T>(obj: AnyRecord, keys: string[], fallback: T): T => {
|
||||
for (const key of keys) {
|
||||
@@ -185,7 +189,8 @@ const mapKnowledgeBase = (raw: AnyRecord): KnowledgeBase => ({
|
||||
const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): CallLog => {
|
||||
const assistantId = readField(raw, ['assistant_id', 'assistantId'], '');
|
||||
const startTime = normalizeDateLabel(readField(raw, ['started_at', 'startTime'], ''));
|
||||
const type = readField(raw, ['type'], 'text');
|
||||
const rawType = String(readField(raw, ['type'], 'text')).toLowerCase();
|
||||
const type: CallLog['type'] = rawType === 'audio' || rawType === 'video' ? rawType : 'text';
|
||||
return {
|
||||
id: String(readField(raw, ['id'], '')),
|
||||
source: readField(raw, ['source'], 'debug') as 'debug' | 'external',
|
||||
@@ -193,7 +198,7 @@ const toHistoryRow = (raw: AnyRecord, assistantNameMap: Map<string, string>): Ca
|
||||
startTime,
|
||||
duration: formatDuration(readField(raw, ['duration_seconds', 'durationSeconds'], 0)),
|
||||
agentName: assistantNameMap.get(String(assistantId)) || String(assistantId || 'Unknown Assistant'),
|
||||
type: type === 'audio' || type === 'video' ? type : 'text',
|
||||
type,
|
||||
details: [],
|
||||
};
|
||||
};
|
||||
@@ -209,7 +214,7 @@ const toHistoryDetails = (raw: AnyRecord): InteractionDetail[] => {
|
||||
};
|
||||
|
||||
export const fetchAssistants = async (): Promise<Assistant[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapAssistant(item));
|
||||
};
|
||||
@@ -276,6 +281,8 @@ export const deleteAssistant = async (id: string): Promise<void> => {
|
||||
|
||||
export interface AssistantRuntimeConfigResponse {
|
||||
assistantId: string;
|
||||
configVersionId?: string;
|
||||
assistant?: Record<string, any>;
|
||||
sessionStartMetadata: Record<string, any>;
|
||||
sources?: {
|
||||
llmModelId?: string;
|
||||
@@ -286,11 +293,11 @@ export interface AssistantRuntimeConfigResponse {
|
||||
}
|
||||
|
||||
export const fetchAssistantRuntimeConfig = async (assistantId: string): Promise<AssistantRuntimeConfigResponse> => {
|
||||
return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/runtime-config`);
|
||||
return apiRequest<AssistantRuntimeConfigResponse>(`/assistants/${assistantId}/config`);
|
||||
};
|
||||
|
||||
export const fetchVoices = async (): Promise<Voice[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/voices');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/voices'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapVoice(item));
|
||||
};
|
||||
@@ -351,7 +358,7 @@ export const previewVoice = async (id: string, text: string, speed?: number, api
|
||||
};
|
||||
|
||||
export const fetchASRModels = async (): Promise<ASRModel[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/asr');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/asr'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapASRModel(item));
|
||||
};
|
||||
@@ -418,8 +425,7 @@ export const previewASRModel = async (
|
||||
formData.append('api_key', options.apiKey);
|
||||
}
|
||||
|
||||
const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
|
||||
const url = `${base}/asr/${id}/preview`;
|
||||
const url = `${getApiBaseUrl()}/asr/${id}/preview`;
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
@@ -441,7 +447,7 @@ export const previewASRModel = async (
|
||||
};
|
||||
|
||||
export const fetchLLMModels = async (): Promise<LLMModel[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/llm');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/llm'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapLLMModel(item));
|
||||
};
|
||||
@@ -501,7 +507,7 @@ export const previewLLMModel = async (
|
||||
};
|
||||
|
||||
export const fetchTools = async (): Promise<Tool[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/tools/resources');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/tools/resources'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapTool(item));
|
||||
};
|
||||
@@ -544,18 +550,14 @@ export const deleteTool = async (id: string): Promise<void> => {
|
||||
};
|
||||
|
||||
export const fetchWorkflows = async (): Promise<Workflow[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/workflows');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/workflows'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapWorkflow(item));
|
||||
};
|
||||
|
||||
export const fetchWorkflowById = async (id: string): Promise<Workflow> => {
|
||||
const list = await fetchWorkflows();
|
||||
const workflow = list.find((item) => item.id === id);
|
||||
if (!workflow) {
|
||||
throw new Error('Workflow not found');
|
||||
}
|
||||
return workflow;
|
||||
const response = await apiRequest<AnyRecord>(`/workflows/${id}`);
|
||||
return mapWorkflow(response);
|
||||
};
|
||||
|
||||
export const createWorkflow = async (data: Partial<Workflow>): Promise<Workflow> => {
|
||||
@@ -589,7 +591,7 @@ export const deleteWorkflow = async (id: string): Promise<void> => {
|
||||
};
|
||||
|
||||
export const fetchKnowledgeBases = async (): Promise<KnowledgeBase[]> => {
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/knowledge/bases');
|
||||
const response = await apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/knowledge/bases'));
|
||||
const list = Array.isArray(response) ? response : (response.list || []);
|
||||
return list.map((item) => mapKnowledgeBase(item));
|
||||
};
|
||||
@@ -641,8 +643,7 @@ export const uploadKnowledgeDocument = async (kbId: string, file: File): Promise
|
||||
formData.append('size', `${file.size} bytes`);
|
||||
formData.append('file_type', file.type || 'application/octet-stream');
|
||||
|
||||
const base = (import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8100/api').replace(/\/+$/, '');
|
||||
const url = `${base}/knowledge/bases/${kbId}/documents`;
|
||||
const url = `${getApiBaseUrl()}/knowledge/bases/${kbId}/documents`;
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
@@ -716,8 +717,8 @@ export const searchKnowledgeBase = async (
|
||||
|
||||
export const fetchHistory = async (): Promise<CallLog[]> => {
|
||||
const [historyResp, assistantsResp] = await Promise.all([
|
||||
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/history'),
|
||||
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>('/assistants'),
|
||||
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/history')),
|
||||
apiRequest<{ list?: AnyRecord[] } | AnyRecord[]>(withLimit('/assistants')),
|
||||
]);
|
||||
|
||||
const assistantList = Array.isArray(assistantsResp) ? assistantsResp : (assistantsResp.list || []);
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
],
|
||||
"skipLibCheck": true,
|
||||
"types": [
|
||||
"node"
|
||||
"node",
|
||||
"vite/client"
|
||||
],
|
||||
"moduleResolution": "bundler",
|
||||
"isolatedModules": true,
|
||||
|
||||
Reference in New Issue
Block a user