Organize config

This commit is contained in:
Xin Wang
2026-02-25 15:52:55 +08:00
parent 2b2193557d
commit 8b9064f6e6
12 changed files with 1248 additions and 92 deletions

View File

@@ -23,57 +23,21 @@ CHUNK_SIZE_MS=20
DEFAULT_CODEC=pcm DEFAULT_CODEC=pcm
MAX_AUDIO_BUFFER_SECONDS=30 MAX_AUDIO_BUFFER_SECONDS=30
# VAD / EOU # Agent profile selection (optional fallback when CLI args are not used)
VAD_TYPE=silero # Prefer CLI:
VAD_MODEL_PATH=data/vad/silero_vad.onnx # python -m app.main --agent-config config/agents/default.yaml
# Higher = stricter speech detection (fewer false positives, more misses). # python -m app.main --agent-profile default
VAD_THRESHOLD=0.5 # AGENT_CONFIG_PATH=config/agents/default.yaml
# Require this much continuous speech before utterance can be valid. # AGENT_PROFILE=default
VAD_MIN_SPEECH_DURATION_MS=100 AGENT_CONFIG_DIR=config/agents
# Silence duration required to finalize one user turn.
VAD_EOU_THRESHOLD_MS=800
# LLM # Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY}
OPENAI_API_KEY=your_openai_api_key_here # LLM_API_KEY=your_llm_api_key_here
# Optional for OpenAI-compatible providers. # LLM_API_URL=https://api.openai.com/v1
# OPENAI_API_URL=https://api.openai.com/v1 # TTS_API_KEY=your_tts_api_key_here
LLM_MODEL=gpt-4o-mini # TTS_API_URL=https://api.example.com/v1/audio/speech
LLM_TEMPERATURE=0.7 # ASR_API_KEY=your_asr_api_key_here
# ASR_API_URL=https://api.example.com/v1/audio/transcriptions
# TTS
# edge: no API key needed
# openai_compatible: compatible with SiliconFlow-style endpoints
TTS_PROVIDER=openai_compatible
TTS_VOICE=anna
TTS_SPEED=1.0
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
# ASR
ASR_PROVIDER=openai_compatible
# Interim cadence and minimum audio before interim decode.
ASR_INTERIM_INTERVAL_MS=500
ASR_MIN_AUDIO_MS=300
# ASR start gate: ignore micro-noise, then commit to one turn once started.
ASR_START_MIN_SPEECH_MS=160
# Pre-roll protects beginning phonemes.
ASR_PRE_SPEECH_MS=240
# Tail silence protects ending phonemes.
ASR_FINAL_TAIL_MS=120
# Duplex behavior
DUPLEX_ENABLED=true
# DUPLEX_GREETING=Hello! How can I help you today?
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
# Barge-in (user interrupting assistant)
# Min user speech duration needed to interrupt assistant audio.
BARGE_IN_MIN_DURATION_MS=200
# Allowed silence during potential barge-in (ms) before reset.
BARGE_IN_SILENCE_TOLERANCE_MS=60
# Logging # Logging
LOG_LEVEL=INFO LOG_LEVEL=INFO

View File

@@ -14,6 +14,30 @@ It is currently in an early, experimental stage.
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
``` ```
使用 agent profile推荐
```
python -m app.main --agent-profile default
```
使用指定 YAML
```
python -m app.main --agent-config config/agents/default.yaml
```
Agent 配置路径优先级
1. `--agent-config`
2. `--agent-profile`(映射到 `config/agents/<profile>.yaml`
3. `AGENT_CONFIG_PATH`
4. `AGENT_PROFILE`
5. `config/agents/default.yaml`(若存在)
说明
- Agent 相关配置是严格模式YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。
- 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`
- `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider``api_key``api_url``model` 配置。
测试 测试
``` ```
@@ -28,4 +52,4 @@ python mic_client.py
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames. `/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`. See `docs/ws_v1_schema.md`.

View File

@@ -1,9 +1,354 @@
"""Configuration management using Pydantic settings.""" """Configuration management using Pydantic settings and agent YAML profiles."""
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
from pydantic import Field from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
import json
try:
import yaml
except ImportError: # pragma: no cover - validated when agent YAML is used
yaml = None
# Matches ${ENV_VAR} and ${ENV_VAR:default} placeholders inside YAML string values.
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
# Fallback location used when no explicit agent config/profile is provided.
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
# Maps each supported YAML section name to {nested key -> flat Settings field name}.
# Sections not listed here are treated as already-flat settings keys.
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
    "vad": {
        "type": "vad_type",
        "model_path": "vad_model_path",
        "threshold": "vad_threshold",
        "min_speech_duration_ms": "vad_min_speech_duration_ms",
        "eou_threshold_ms": "vad_eou_threshold_ms",
    },
    "llm": {
        "provider": "llm_provider",
        "model": "llm_model",
        "temperature": "llm_temperature",
        "api_key": "llm_api_key",
        "api_url": "llm_api_url",
    },
    "tts": {
        "provider": "tts_provider",
        "api_key": "tts_api_key",
        "api_url": "tts_api_url",
        "model": "tts_model",
        "voice": "tts_voice",
        "speed": "tts_speed",
    },
    "asr": {
        "provider": "asr_provider",
        "api_key": "asr_api_key",
        "api_url": "asr_api_url",
        "model": "asr_model",
        "interim_interval_ms": "asr_interim_interval_ms",
        "min_audio_ms": "asr_min_audio_ms",
        "start_min_speech_ms": "asr_start_min_speech_ms",
        "pre_speech_ms": "asr_pre_speech_ms",
        "final_tail_ms": "asr_final_tail_ms",
    },
    "duplex": {
        "enabled": "duplex_enabled",
        "greeting": "duplex_greeting",
        "system_prompt": "duplex_system_prompt",
    },
    "barge_in": {
        "min_duration_ms": "barge_in_min_duration_ms",
        "silence_tolerance_ms": "barge_in_silence_tolerance_ms",
    },
}
# Full set of flat Settings field names that agent YAML is allowed to set.
# Anything normalized out of the YAML that is not in this set is rejected.
_AGENT_SETTING_KEYS = {
    "vad_type",
    "vad_model_path",
    "vad_threshold",
    "vad_min_speech_duration_ms",
    "vad_eou_threshold_ms",
    "llm_provider",
    "llm_api_key",
    "llm_api_url",
    "llm_model",
    "llm_temperature",
    "tts_provider",
    "tts_api_key",
    "tts_api_url",
    "tts_model",
    "tts_voice",
    "tts_speed",
    "asr_provider",
    "asr_api_key",
    "asr_api_url",
    "asr_model",
    "asr_interim_interval_ms",
    "asr_min_audio_ms",
    "asr_start_min_speech_ms",
    "asr_pre_speech_ms",
    "asr_final_tail_ms",
    "duplex_enabled",
    "duplex_greeting",
    "duplex_system_prompt",
    "barge_in_min_duration_ms",
    "barge_in_silence_tolerance_ms",
}
# Keys that must always be present in the agent YAML (strict mode: no fallback
# to .env or code defaults). Provider-dependent credentials (api_key/api_url/
# model) are added on top of this set by _missing_required_keys.
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
    "vad_type",
    "vad_model_path",
    "vad_threshold",
    "vad_min_speech_duration_ms",
    "vad_eou_threshold_ms",
    "llm_provider",
    "llm_model",
    "llm_temperature",
    "tts_provider",
    "tts_voice",
    "tts_speed",
    "asr_provider",
    "asr_interim_interval_ms",
    "asr_min_audio_ms",
    "asr_start_min_speech_ms",
    "asr_pre_speech_ms",
    "asr_final_tail_ms",
    "duplex_enabled",
    "duplex_system_prompt",
    "barge_in_min_duration_ms",
    "barge_in_silence_tolerance_ms",
}
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
return str(overrides.get(key) or default).strip().lower()
def _is_blank(value: Any) -> bool:
return value is None or (isinstance(value, str) and not value.strip())
@dataclass(frozen=True)
class AgentConfigSelection:
    """Resolved agent config location and how it was selected."""

    # Absolute path to the agent YAML file, or None when nothing was resolved.
    path: Optional[Path]
    # Selection origin: "cli_path", "env_path", "cli_profile", "env_profile",
    # "default", or "none".
    source: str
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
"""Parse only agent-related CLI flags from argv."""
config_path: Optional[str] = None
profile: Optional[str] = None
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith("--agent-config="):
config_path = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-config" and i + 1 < len(argv):
config_path = argv[i + 1].strip() or None
i += 1
elif arg.startswith("--agent-profile="):
profile = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-profile" and i + 1 < len(argv):
profile = argv[i + 1].strip() or None
i += 1
i += 1
return config_path, profile
def _agent_config_dir() -> Path:
    """Return the absolute directory holding agent profile YAML files.

    Honors the AGENT_CONFIG_DIR environment variable; relative values are
    anchored at the current working directory.
    """
    configured = os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR)
    base_dir = Path(configured)
    if not base_dir.is_absolute():
        base_dir = Path.cwd() / base_dir
    return base_dir.resolve()
def _resolve_agent_selection(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> AgentConfigSelection:
    """Pick the agent YAML file to load and record how it was chosen.

    Priority: explicit/CLI path > AGENT_CONFIG_PATH > explicit/CLI profile >
    AGENT_PROFILE > config/agents/default.yaml (if it exists). Raises
    ValueError when nothing resolves or the resolved path is invalid.
    """
    effective_argv = list(argv) if argv is not None else list(sys.argv[1:])
    cli_path, cli_profile = _parse_cli_agent_args(effective_argv)

    path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH")
    profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE")

    candidate: Optional[Path] = None
    source = "none"
    if path_value:
        candidate = Path(path_value)
        source = "cli_path" if (agent_config_path or cli_path) else "env_path"
    elif profile_value:
        candidate = _agent_config_dir() / f"{profile_value}.yaml"
        source = "cli_profile" if (agent_profile or cli_profile) else "env_profile"
    else:
        default_candidate = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
        if default_candidate.exists():
            candidate = default_candidate
            source = "default"

    if candidate is None:
        raise ValueError(
            "Agent YAML config is required. Provide --agent-config/--agent-profile "
            "or create config/agents/default.yaml."
        )

    # Anchor relative paths at the CWD, then canonicalize.
    if candidate.is_absolute():
        candidate = candidate.resolve()
    else:
        candidate = (Path.cwd() / candidate).resolve()

    if not candidate.exists():
        raise ValueError(f"Agent config file not found ({source}): {candidate}")
    if not candidate.is_file():
        raise ValueError(f"Agent config path is not a file: {candidate}")
    return AgentConfigSelection(path=candidate, source=source)
def _resolve_env_refs(value: Any) -> Any:
    """Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively.

    Dicts and lists are walked recursively; strings without "${" are returned
    untouched. A reference to an unset variable with no default raises
    ValueError so misconfiguration fails loudly at load time.
    """
    if isinstance(value, dict):
        return {key: _resolve_env_refs(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_resolve_env_refs(item) for item in value]
    if not isinstance(value, str) or "${" not in value:
        return value

    def _substitute(match: re.Match[str]) -> str:
        name = match.group(1)
        fallback = match.group(2)
        resolved = os.getenv(name)
        if resolved is not None:
            return resolved
        if fallback is None:
            raise ValueError(f"Missing environment variable referenced in agent YAML: {name}")
        return fallback

    return _ENV_REF_PATTERN.sub(_substitute, value)
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize YAML into flat Settings fields.

    Known sections (vad/llm/tts/asr/duplex/barge_in) are flattened via
    _AGENT_SECTION_KEY_MAP; any other top-level key is passed through as an
    already-flat settings key. Raises ValueError for the removed 'siliconflow'
    section, non-mapping sections, unknown nested keys, or keys outside
    _AGENT_SETTING_KEYS.
    """
    flat: Dict[str, Any] = {}
    for section, section_value in raw.items():
        if section == "siliconflow":
            raise ValueError(
                "Section 'siliconflow' is no longer supported. "
                "Move provider-specific fields into agent.llm / agent.asr / agent.tts."
            )
        key_map = _AGENT_SECTION_KEY_MAP.get(section)
        if key_map is None:
            # Not a known section: accept as a flat key (validated below).
            flat[section] = section_value
            continue
        if not isinstance(section_value, dict):
            raise ValueError(f"Agent config section '{section}' must be a mapping")
        for field_name, field_value in section_value.items():
            target = key_map.get(field_name)
            if target is None:
                raise ValueError(f"Unknown key in '{section}' section: '{field_name}'")
            flat[target] = field_value

    unexpected = sorted(set(flat) - _AGENT_SETTING_KEYS)
    if unexpected:
        raise ValueError(
            "Unknown agent config keys in YAML: "
            + ", ".join(unexpected)
        )
    return flat
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
    """Return sorted Settings keys that are required but absent or blank.

    Three layers of requirements:
      1. every key in _BASE_REQUIRED_AGENT_SETTING_KEYS must be present;
      2. string-typed keys must additionally be non-blank when present;
      3. provider-specific credentials become required depending on the
         configured llm/tts/asr provider.

    The copy-pasted "absent or blank" checks of the original are folded into
    one local helper and data-driven loops; behavior is unchanged.
    """
    def _absent_or_blank(key: str) -> bool:
        # A required credential must both exist and hold a non-blank value.
        return key not in overrides or _is_blank(overrides.get(key))

    missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides))

    # Keys that, when present, must be non-empty strings.
    string_required = {
        "vad_type",
        "vad_model_path",
        "llm_provider",
        "llm_model",
        "tts_provider",
        "tts_voice",
        "asr_provider",
        "duplex_system_prompt",
    }
    for key in string_required:
        if key in overrides and _is_blank(overrides.get(key)):
            missing.add(key)

    # LLM: both plain OpenAI and OpenAI-compatible providers need an API key.
    llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
    if llm_provider == "openai" or llm_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        if _absent_or_blank("llm_api_key"):
            missing.add("llm_api_key")

    # TTS/ASR: OpenAI-compatible providers need key, URL, and model.
    tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
    if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        for key in ("tts_api_key", "tts_api_url", "tts_model"):
            if _absent_or_blank(key):
                missing.add(key)

    asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
    if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
        for key in ("asr_api_key", "asr_api_url", "asr_model"):
            if _absent_or_blank(key):
                missing.add(key)

    return sorted(missing)
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
    """Read, env-resolve, normalize, and validate the selected agent YAML.

    Returns the flat Settings overrides with agent_config_path/source metadata
    attached. Raises RuntimeError when PyYAML is unavailable and ValueError on
    any structural or completeness problem.
    """
    if yaml is None:
        raise RuntimeError(
            "PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
        )

    with selection.path.open("r", encoding="utf-8") as file:
        raw = yaml.safe_load(file) or {}
    if not isinstance(raw, dict):
        raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")

    # Accept either a flat top-level mapping or one nested under 'agent:'.
    if "agent" in raw:
        nested = raw["agent"]
        if not isinstance(nested, dict):
            raise ValueError("The 'agent' key in YAML must be a mapping")
        raw = nested

    overrides = _normalize_agent_overrides(_resolve_env_refs(raw))
    missing = _missing_required_keys(overrides)
    if missing:
        raise ValueError(
            f"Missing required agent settings in YAML ({selection.path}): "
            + ", ".join(missing)
        )

    overrides["agent_config_path"] = str(selection.path)
    overrides["agent_config_source"] = selection.source
    return overrides
def load_settings(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> "Settings":
    """Load settings from .env and optional agent YAML.

    Resolves the agent YAML location (explicit args > CLI flags > env vars >
    default file), loads its overrides, and constructs a Settings instance
    with those overrides applied on top of the .env-backed defaults.
    """
    selection = _resolve_agent_selection(
        agent_config_path=agent_config_path,
        agent_profile=agent_profile,
        argv=argv,
    )
    overrides = _load_agent_overrides(selection)
    return Settings(**overrides)
class Settings(BaseSettings): class Settings(BaseSettings):
@@ -37,30 +382,35 @@ class Settings(BaseSettings):
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds") vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds") vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
# OpenAI / LLM Configuration # LLM Configuration
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key") llm_provider: str = Field(
openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)") default="openai",
description="LLM provider (openai, openai_compatible, siliconflow)"
)
llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation") llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
# TTS Configuration # TTS Configuration
tts_provider: str = Field( tts_provider: str = Field(
default="openai_compatible", default="openai_compatible",
description="TTS provider (edge, openai_compatible; siliconflow alias supported)" description="TTS provider (edge, openai_compatible, siliconflow)"
) )
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
tts_model: Optional[str] = Field(default=None, description="TTS model name")
tts_voice: str = Field(default="anna", description="TTS voice name") tts_voice: str = Field(default="anna", description="TTS voice name")
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
# SiliconFlow Configuration
siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
# ASR Configuration # ASR Configuration
asr_provider: str = Field( asr_provider: str = Field(
default="openai_compatible", default="openai_compatible",
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)" description="ASR provider (openai_compatible, buffered, siliconflow)"
) )
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model") asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
asr_start_min_speech_ms: int = Field( asr_start_min_speech_ms: int = Field(
@@ -122,6 +472,10 @@ class Settings(BaseSettings):
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds") backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records") history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
# Agent YAML metadata
agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
agent_config_source: str = Field(default="none", description="How the agent YAML was selected")
@property @property
def chunk_size_bytes(self) -> int: def chunk_size_bytes(self) -> int:
"""Calculate chunk size in bytes based on sample rate and duration.""" """Calculate chunk size in bytes based on sample rate and duration."""
@@ -146,7 +500,7 @@ class Settings(BaseSettings):
# Global settings instance # Global settings instance
settings = Settings() settings = load_settings()
def get_settings() -> Settings: def get_settings() -> Settings:

View File

@@ -357,6 +357,12 @@ async def startup_event():
logger.info(f"Server: {settings.host}:{settings.port}") logger.info(f"Server: {settings.host}:{settings.port}")
logger.info(f"Sample rate: {settings.sample_rate} Hz") logger.info(f"Sample rate: {settings.sample_rate} Hz")
logger.info(f"VAD model: {settings.vad_model_path}") logger.info(f"VAD model: {settings.vad_model_path}")
if settings.agent_config_path:
logger.info(
f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}"
)
else:
logger.info("Agent config: none (using .env/default agent values)")
@app.on_event("shutdown") @app.on_event("shutdown")

View File

@@ -0,0 +1,50 @@
# Agent behavior configuration (safe to edit per profile)
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
# Infra/server/network settings should stay in .env.
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.5
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
# provider: openai | openai_compatible | siliconflow
provider: openai_compatible
model: deepseek-v3
temperature: 0.7
# Required: no fallback. You can still reference env explicitly.
api_key: your_llm_api_key
# Optional for OpenAI-compatible endpoints:
api_url: https://api.qnaigc.com/v1
tts:
# provider: edge | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech
model: FunAudioLLM/CosyVoice2-0.5B
voice: anna
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60

View File

@@ -310,7 +310,12 @@ class DuplexPipeline:
def resolved_runtime_config(self) -> Dict[str, Any]: def resolved_runtime_config(self) -> Dict[str, Any]:
"""Return current effective runtime configuration without secrets.""" """Return current effective runtime configuration without secrets."""
llm_provider = str(self._runtime_llm.get("provider") or "openai").lower() llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
llm_base_url = (
self._runtime_llm.get("baseUrl")
or settings.llm_api_url
or self._default_llm_base_url(llm_provider)
)
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower() tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower() asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
output_mode = str(self._runtime_output.get("mode") or "").strip().lower() output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
@@ -323,18 +328,18 @@ class DuplexPipeline:
"llm": { "llm": {
"provider": llm_provider, "provider": llm_provider,
"model": str(self._runtime_llm.get("model") or settings.llm_model), "model": str(self._runtime_llm.get("model") or settings.llm_model),
"baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url, "baseUrl": llm_base_url,
}, },
"asr": { "asr": {
"provider": asr_provider, "provider": asr_provider,
"model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model), "model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
}, },
"tts": { "tts": {
"enabled": self._tts_output_enabled(), "enabled": self._tts_output_enabled(),
"provider": tts_provider, "provider": tts_provider,
"model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model), "model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice), "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed), "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
}, },
@@ -452,6 +457,18 @@ class DuplexPipeline:
normalized = str(provider or "").strip().lower() normalized = str(provider or "").strip().lower()
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"} return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
@staticmethod
def _is_llm_provider_supported(provider: Any) -> bool:
normalized = str(provider or "").strip().lower()
return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"}
@staticmethod
def _default_llm_base_url(provider: Any) -> Optional[str]:
normalized = str(provider or "").strip().lower()
if normalized == "siliconflow":
return "https://api.siliconflow.cn/v1"
return None
def _tts_output_enabled(self) -> bool: def _tts_output_enabled(self) -> bool:
enabled = self._coerce_bool(self._runtime_tts.get("enabled")) enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
if enabled is not None: if enabled is not None:
@@ -527,12 +544,16 @@ class DuplexPipeline:
try: try:
# Connect LLM service # Connect LLM service
if not self.llm_service: if not self.llm_service:
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
llm_base_url = (
self._runtime_llm.get("baseUrl")
or settings.llm_api_url
or self._default_llm_base_url(llm_provider)
)
llm_model = self._runtime_llm.get("model") or settings.llm_model llm_model = self._runtime_llm.get("model") or settings.llm_model
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
if llm_provider == "openai" and llm_api_key: if self._is_llm_provider_supported(llm_provider) and llm_api_key:
self.llm_service = OpenAILLMService( self.llm_service = OpenAILLMService(
api_key=llm_api_key, api_key=llm_api_key,
base_url=llm_base_url, base_url=llm_base_url,
@@ -540,7 +561,7 @@ class DuplexPipeline:
knowledge_config=self._resolved_knowledge_config(), knowledge_config=self._resolved_knowledge_config(),
) )
else: else:
logger.warning("No OpenAI API key - using mock LLM") logger.warning("LLM provider unsupported or API key missing - using mock LLM")
self.llm_service = MockLLMService() self.llm_service = MockLLMService()
if hasattr(self.llm_service, "set_knowledge_config"): if hasattr(self.llm_service, "set_knowledge_config"):
@@ -556,20 +577,22 @@ class DuplexPipeline:
if tts_output_enabled: if tts_output_enabled:
if not self.tts_service: if not self.tts_service:
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model tts_model = self._runtime_tts.get("model") or settings.tts_model
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
if self._is_openai_compatible_provider(tts_provider) and tts_api_key: if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
self.tts_service = OpenAICompatibleTTSService( self.tts_service = OpenAICompatibleTTSService(
api_key=tts_api_key, api_key=tts_api_key,
api_url=tts_api_url,
voice=tts_voice, voice=tts_voice,
model=tts_model, model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
sample_rate=settings.sample_rate, sample_rate=settings.sample_rate,
speed=tts_speed speed=tts_speed
) )
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)") logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
else: else:
self.tts_service = EdgeTTSService( self.tts_service = EdgeTTSService(
voice=tts_voice, voice=tts_voice,
@@ -592,21 +615,23 @@ class DuplexPipeline:
# Connect ASR service # Connect ASR service
if not self.asr_service: if not self.asr_service:
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower() asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
if self._is_openai_compatible_provider(asr_provider) and asr_api_key: if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
self.asr_service = OpenAICompatibleASRService( self.asr_service = OpenAICompatibleASRService(
api_key=asr_api_key, api_key=asr_api_key,
model=asr_model, api_url=asr_api_url,
model=asr_model or "FunAudioLLM/SenseVoiceSmall",
sample_rate=settings.sample_rate, sample_rate=settings.sample_rate,
interim_interval_ms=asr_interim_interval, interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms, min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback on_transcript=self._on_transcript_callback
) )
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)") logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
else: else:
self.asr_service = BufferedASRService( self.asr_service = BufferedASRService(
sample_rate=settings.sample_rate sample_rate=settings.sample_rate

520
docs/ws_v1_schema_zh.md Normal file
View File

@@ -0,0 +1,520 @@
# WS v1 协议完整说明(中文)
本文档描述 `/ws` 端点的 WebSocket v1 协议,覆盖:
- 客户端输入JSON 文本消息 + 二进制音频);
- 服务端输出JSON 事件 + 二进制音频);
- 每个参数的类型、约束、含义与使用方式;
- 握手顺序、状态机、错误语义与实现细节。
实现对照来源:
- `models/ws_v1.py`
- `core/session.py`
- `core/duplex_pipeline.py`
- `app/main.py`
---
## 1. 传输与基础规则
- 连接地址:`ws://<host>/ws`
- 单连接双通道承载:
- 文本帧:JSON 控制消息(严格校验 schema)
- 二进制帧:原始 PCM 音频
- JSON 校验策略:
- 所有已定义客户端消息都 `extra="forbid"`,即不允许未声明字段;
- `hello.version` 固定必须是 `"v1"`
- 缺失 `type` 或未知 `type` 会返回协议错误。
---
## 2. 状态机与消息顺序
### 2.1 服务端状态
- `WAIT_HELLO`:等待 `hello`
- `WAIT_START`:已通过握手,等待 `session.start`
- `ACTIVE`:会话运行中,可收发文本/音频
- `STOPPED`:会话结束
### 2.2 正确顺序
1. 客户端发送 `hello`
2. 服务端返回 `hello.ack`
3. 客户端发送 `session.start`
4. 服务端返回 `session.started`
5. 客户端可持续发送:
- 二进制音频
- `input.text`(可选)
- `response.cancel`(可选)
- `tool_call.results`(可选)
6. 客户端发送 `session.stop` 或直接断开连接
顺序错误会返回 `error``code = "protocol.order"`
---
## 3. 客户端 -> 服务端消息(输入)
## 3.1 `hello`
示例:
```json
{
"type": "hello",
"version": "v1",
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
}
```
字段说明:
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 |
| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 |
| `auth` | object \| null | 否 | 仅允许 `apiKey``jwt` | 认证载荷 | 认证策略由服务端配置决定 |
| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 |
| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 |
认证行为:
- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。
-`WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY``auth.apiKey``auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。
## 3.2 `session.start`
示例:
```json
{
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": 16000,
"channels": 1
},
"metadata": {
"appId": "assistant_123",
"channel": "web",
"configVersionId": "cfg_20260217_01",
"client": "web-debug",
"output": {
"mode": "audio"
},
"systemPrompt": "你是简洁助手",
"greeting": "你好,我能帮你什么?"
}
}
```
字段说明:
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"session.start"` | 启动会话 | 握手后第二阶段消息 |
| `audio` | object \| null | 否 | 仅支持固定值 | 音频格式描述 | 仅用于声明;MVP 实际只接受固定 PCM |
| `audio.encoding` | string | 否 | 固定 `"pcm_s16le"` | 编码格式 | 非该值会在模型校验层报错 |
| `audio.sample_rate_hz` | number | 否 | 固定 `16000` | 采样率 | 16kHz |
| `audio.channels` | number | 否 | 固定 `1` | 声道数 | 单声道 |
| `metadata` | object \| null | 否 | 任意对象(会被白名单过滤) | 运行时配置 | 用于 app/channel/提示词/输出模式等覆盖 |
`metadata` 白名单策略(关键):
- 允许透传的标识字段ID 类):
- `appId` / `app_id`
- `channel`
- `configVersionId` / `config_version_id`
- 允许透传的覆盖字段:
- `firstTurnMode`
- `greeting`
- `generatedOpenerEnabled`
- `systemPrompt`
- `output`
- `bargeIn`
- `knowledge`
- `knowledgeBaseId`
- `history`
- `userId`
- `assistantId`
- `source`
- 客户端传入 `metadata.services` 会被忽略(服务端会记录 warning),服务配置由后端/环境变量决定。
`output.mode` 用法:
- `"audio"`(默认语音输出)
- `"text"`(纯文本输出)
- 纯文本模式下仍会收到 `assistant.response.delta/final`
- 不会收到 TTS 音频帧与 `output.audio.start/end`
## 3.3 `input.text`
示例:
```json
{
"type": "input.text",
"text": "你能做什么?"
}
```
字段说明:
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"input.text"` | 文本输入 | 跳过 ASR直接触发 LLM 回答 |
| `text` | string | 是 | 非空字符串为佳 | 用户文本 | 用于文本聊天或调试 |
## 3.4 `response.cancel`
示例:
```json
{
"type": "response.cancel",
"graceful": false
}
```
字段说明:
| 字段 | 类型 | 必填 | 默认值 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 |
| `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 |
## 3.5 `tool_call.results`
仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。
示例:
```json
{
"type": "tool_call.results",
"results": [
{
"tool_call_id": "call_abc123",
"name": "weather",
"output": { "temp_c": 21, "condition": "sunny" },
"status": { "code": 200, "message": "ok" }
}
]
}
```
字段说明:
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"tool_call.results"` | 工具执行回传 | 客户端工具结果上送 |
| `results` | array | 否 | 默认为空数组 | 多个工具结果 | 可批量回传 |
| `results[].tool_call_id` | string | 是 | 任意字符串 | 工具调用ID | 必须与 `assistant.tool_call.tool_call_id` 对应 |
| `results[].name` | string | 是 | 任意字符串 | 工具名 | 建议与请求一致 |
| `results[].output` | any | 否 | 任意 JSON | 工具输出 | 供模型后续组织回答 |
| `results[].status` | object | 是 | 包含 `code``message` | 执行状态 | 用于判定成功/失败 |
| `results[].status.code` | number | 是 | HTTP 风格状态码 | 状态码 | `200-299` 判定成功 |
| `results[].status.message` | string | 是 | 任意字符串 | 状态描述 | 例如 `"ok"` / `"timeout"` |
处理规则:
- 未请求过的 `tool_call_id` 会被忽略(防止伪造/串话);
- 重复回传会被忽略;
- 超时未回传会由服务端合成超时结果(`504`)。
## 3.6 `session.stop`
示例:
```json
{
"type": "session.stop",
"reason": "client_disconnect"
}
```
字段说明:
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|---|---|---|---|---|---|
| `type` | string | 是 | 固定 `"session.stop"` | 结束会话 | 正常结束推荐发送 |
| `reason` | string \| null | 否 | 任意字符串 | 结束原因 | 服务端会回传到 `session.stopped.reason` |
---
## 4. 二进制音频输入(客户端 -> 服务端)
`session.started` 之后可持续发送二进制音频。
固定格式MVP
- 编码:`pcm_s16le`
- 采样率:`16000`
- 声道:`1`
- 帧长20ms = `640 bytes`
分包规则:
- 单个 WebSocket 二进制消息可包含 1 帧或多帧;
- 长度必须是 `640` 的整数倍;
- 不是 `640` 倍数会触发 `audio.frame_size_mismatch`,该消息整包丢弃;
- 奇数字节长度会触发 `audio.invalid_pcm`
---
## 5. 服务端 -> 客户端事件(输出)
所有 JSON 事件都包含统一包络字段。
## 5.1 统一包络Envelope
```json
{
"type": "event.name",
"timestamp": 1730000000000,
"sessionId": "sess_xxx",
"seq": 42,
"source": "asr",
"trackId": "audio_in",
"data": {}
}
```
字段说明:
| 字段 | 类型 | 含义 | 使用说明 |
|---|---|---|---|
| `type` | string | 事件类型 | 见下方事件清单 |
| `timestamp` | number | 事件时间戳(毫秒) | 由 `ev()` 生成 |
| `sessionId` | string | 会话ID | 同一连接固定 |
| `seq` | number | 单会话递增序号 | 可用于重放、去重、排序 |
| `source` | string | 事件来源 | 常见:`asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` |
| `trackId` | string | 事件轨道 | 常用:`audio_in`/`audio_out`/`control` |
| `data` | object | 结构化数据 | 顶层业务字段会镜像进 `data` 以兼容旧客户端 |
关联ID`data` 内自动注入,存在时):
- `turn_id`:一次用户-助手对话轮次
- `utterance_id`:一次用户语音话语
- `response_id`:一次助手生成响应
- `tool_call_id`:一次工具调用
- `tts_id`:一次 TTS 播放段
## 5.2 事件类型与参数
### 5.2.1 会话与控制类
1. `hello.ack`
- 关键字段:`version`
- 含义:握手成功,应紧接着发送 `session.start`
2. `session.started`
- 关键字段:
- `trackId`
- `tracks.audio_in`
- `tracks.audio_out`
- `tracks.control`
- `audio`(回显客户端声明的音频元信息)
- 含义:会话进入 ACTIVE,可发音频/文本
3. `config.resolved`
- 关键字段:
- `config.appId`
- `config.channel`
- `config.configVersionId`
- `config.prompt.sha256`
- `config.output`
- `config.services`(去密钥后的有效服务配置)
- `config.tools.allowlist`
- `config.tracks`
- 含义:服务端最终生效配置快照,便于前端展示与排错
4. `heartbeat`
- 关键字段:无业务字段(仅 envelope)
- 含义:保活心跳
- 默认间隔:`heartbeat_interval_sec`(默认 50s)
5. `session.stopped`
- 关键字段:`reason`
- 含义:会话结束确认
6. `error`
- 关键字段:
- `sender`
- `code`
- `message`
- `stage`
- `retryable`
- `trackId`
- `data.error`(结构化错误镜像)
- 含义:统一错误事件
### 5.2.2 识别与输入侧ASR/VAD
1. `input.speech_started`
- 字段:`probability`
- 含义:检测到语音开始
2. `input.speech_stopped`
- 字段:`probability`
- 含义:检测到语音结束
3. `transcript.delta`
- 字段:`text`
- 含义:ASR 增量识别文本(节流发送)
4. `transcript.final`
- 字段:`text`
- 含义:ASR 最终识别文本
### 5.2.3 输出侧LLM/TTS/Tool
1. `assistant.response.delta`
- 字段:`text`
- 含义:助手增量文本输出(节流发送)
2. `assistant.response.final`
- 字段:`text`
- 含义:助手完整文本输出
3. `assistant.tool_call`
- 字段:
- `tool_call_id`
- `tool_name`
- `arguments`(对象)
- `executor``client``server`
- `timeout_ms`
- `tool_call`(完整工具调用对象)
- 含义:通知客户端发生工具调用(用于可视化或客户端执行)
4. `assistant.tool_result`
- 字段:
- `source``client``server`
- `tool_call_id`
- `tool_name`
- `ok`boolean
- `error`(失败时 `{code,message,retryable}`
- `result`(原始结果对象)
- 含义:工具调用结果回执
5. `output.audio.start`
- 含义TTS 音频输出开始边界
6. `output.audio.end`
- 含义TTS 音频输出结束边界
7. `response.interrupted`
- 含义:当前回答被打断(barge-in 或 cancel)
8. `metrics.ttfb`
- 字段:`latencyMs`
- 含义:首包音频时延(TTFB)
### 5.2.4 工作流扩展事件(可选)
`metadata.workflow` 生效,会额外出现:
- `workflow.started`
- `workflow.node.entered`
- `workflow.edge.taken`
- `workflow.tool.requested`
- `workflow.human_transfer`
- `workflow.ended`
这些事件用于外部可视化工作流状态,不影响基础语音会话协议。
---
## 6. 服务端二进制音频输出(服务端 -> 客户端)
- 音频为 PCM 二进制帧;
- 发送单位对齐到 `640 bytes`(不足会补零后发送);
- 前端通常结合 `output.audio.start/end` 做播放边界控制;
- 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。
---
## 7. 错误模型与常见错误码
统一结构(`error` 事件):
```json
{
"type": "error",
"sender": "client",
"code": "protocol.invalid_message",
"message": "Invalid message: ...",
"stage": "protocol",
"retryable": false,
"trackId": "control",
"data": {
"error": {
"stage": "protocol",
"code": "protocol.invalid_message",
"message": "Invalid message: ...",
"retryable": false
}
}
}
```
字段语义:
- `sender`:错误来源角色(如 `client` / `server` / `auth`
- `code`:机器可读错误码
- `message`:人类可读描述
- `stage`:阶段(`protocol|audio|asr|llm|tts|tool`
- `retryable`:是否建议重试
- `trackId`:错误归属轨道
常见错误码:
- `protocol.invalid_json`
- `protocol.invalid_message`
- `protocol.order`
- `protocol.version_unsupported`
- `protocol.unsupported`
- `auth.invalid_api_key`
- `auth.required`
- `audio.invalid_pcm`
- `audio.frame_size_mismatch`
- `audio.processing_failed`
- `server.internal`
---
## 8. 心跳与超时
服务端后台任务逻辑:
- 每隔约 5 秒检查一次连接;
- 超过 `inactivity_timeout_sec`(默认 60 秒)未收到任何客户端消息则关闭会话;
- 每隔 `heartbeat_interval_sec`(默认 50 秒)发送一次 `heartbeat`
客户端建议:
- 持续上行音频或定期发送轻量文本消息,避免被判定闲置;
-`heartbeat` + `seq` 检测连接活性和事件乱序。
---
## 9. 实战接入建议
1. 建连后立即发送 `hello`,收到 `hello.ack` 后再发 `session.start`
2. 语音输入严格按 16k/16bit/mono并保证每个 WS 二进制消息长度是 `640*n`
3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。
4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。
5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`
6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`
---
## 10. 最小完整时序示例
```text
Client -> hello
Server <- hello.ack
Client -> session.start
Server <- session.started
Server <- config.resolved
Client -> (binary pcm frames...)
Server <- input.speech_started / transcript.delta / transcript.final
Server <- assistant.response.delta / assistant.response.final
Server <- output.audio.start
Server <- (binary pcm frames...)
Server <- output.audio.end
Client -> session.stop
Server <- session.stopped
```

View File

@@ -17,6 +17,7 @@ pydantic>=2.5.3
pydantic-settings>=2.1.0 pydantic-settings>=2.1.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
toml>=0.10.2 toml>=0.10.2
pyyaml>=6.0.1
# Logging # Logging
loguru>=0.7.2 loguru>=0.7.2

View File

@@ -43,14 +43,14 @@ class OpenAILLMService(BaseLLMService):
Args: Args:
model: Model name (e.g., "gpt-4o-mini", "gpt-4o") model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars)
base_url: Custom API base URL (for Azure or compatible APIs) base_url: Custom API base URL (for Azure or compatible APIs)
system_prompt: Default system prompt for conversations system_prompt: Default system prompt for conversations
""" """
super().__init__(model=model) super().__init__(model=model)
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
self.base_url = base_url or os.getenv("OPENAI_API_URL") self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
self.system_prompt = system_prompt or ( self.system_prompt = system_prompt or (
"You are a helpful, friendly voice assistant. " "You are a helpful, friendly voice assistant. "
"Keep your responses concise and conversational. " "Keep your responses concise and conversational. "

View File

@@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti
import asyncio import asyncio
import io import io
import os
import wave import wave
from typing import AsyncIterator, Optional, Callable, Awaitable from typing import AsyncIterator, Optional, Callable, Awaitable
from loguru import logger from loguru import logger
@@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService):
def __init__( def __init__(
self, self,
api_key: str, api_key: Optional[str] = None,
api_url: Optional[str] = None,
model: str = "FunAudioLLM/SenseVoiceSmall", model: str = "FunAudioLLM/SenseVoiceSmall",
sample_rate: int = 16000, sample_rate: int = 16000,
language: str = "auto", language: str = "auto",
@@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService):
Args: Args:
api_key: Provider API key api_key: Provider API key
api_url: Provider API URL (defaults to SiliconFlow endpoint)
model: ASR model name or alias model: ASR model name or alias
sample_rate: Audio sample rate (16000 recommended) sample_rate: Audio sample rate (16000 recommended)
language: Language code (auto for automatic detection) language: Language code (auto for automatic detection)
@@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService):
if not AIOHTTP_AVAILABLE: if not AIOHTTP_AVAILABLE:
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
self.api_key = api_key self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
self.model = self.MODELS.get(model.lower(), model) self.model = self.MODELS.get(model.lower(), model)
self.interim_interval_ms = interim_interval_ms self.interim_interval_ms = interim_interval_ms
self.min_audio_for_interim_ms = min_audio_for_interim_ms self.min_audio_for_interim_ms = min_audio_for_interim_ms
@@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService):
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to the service.""" """Connect to the service."""
if not self.api_key:
raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.")
self._session = aiohttp.ClientSession( self._session = aiohttp.ClientSession(
headers={ headers={
"Authorization": f"Bearer {self.api_key}" "Authorization": f"Bearer {self.api_key}"
@@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService):
) )
form_data.add_field('model', self.model) form_data.add_field('model', self.model)
async with self._session.post(self.API_URL, data=form_data) as response: async with self._session.post(self.api_url, data=form_data) as response:
if response.status == 200: if response.status == 200:
result = await response.json() result = await response.json()
text = result.get("text", "").strip() text = result.get("text", "").strip()

View File

@@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_url: Optional[str] = None,
voice: str = "anna", voice: str = "anna",
model: str = "FunAudioLLM/CosyVoice2-0.5B", model: str = "FunAudioLLM/CosyVoice2-0.5B",
sample_rate: int = 16000, sample_rate: int = 16000,
@@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService):
Initialize OpenAI-compatible TTS service. Initialize OpenAI-compatible TTS service.
Args: Args:
api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var) api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars)
api_url: Provider API URL (defaults to SiliconFlow endpoint)
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
model: Model name model: Model name
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
@@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService):
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed) super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY") self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
self.model = model self.model = model
self.api_url = "https://api.siliconflow.cn/v1/audio/speech" self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
self._session: Optional[aiohttp.ClientSession] = None self._session: Optional[aiohttp.ClientSession] = None
self._cancel_event = asyncio.Event() self._cancel_event = asyncio.Event()
@@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
async def connect(self) -> None: async def connect(self) -> None:
"""Initialize HTTP session.""" """Initialize HTTP session."""
if not self.api_key: if not self.api_key:
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.") raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")
self._session = aiohttp.ClientSession( self._session = aiohttp.ClientSession(
headers={ headers={

204
tests/test_agent_config.py Normal file
View File

@@ -0,0 +1,204 @@
import os
from pathlib import Path
import pytest
os.environ.setdefault("LLM_API_KEY", "test-openai-key")
os.environ.setdefault("TTS_API_KEY", "test-tts-key")
os.environ.setdefault("ASR_API_KEY", "test-asr-key")
from app.config import load_settings
def _write_yaml(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8")
def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str:
    """Return a complete agent YAML document satisfying every required setting.

    Args:
        llm_model: Value injected into ``agent.llm.model``.
        llm_key: Value injected into ``agent.llm.api_key``; tests pass a blank
            string or an ``${ENV_VAR}`` reference to exercise validation paths.

    NOTE(review): YAML indentation below reconstructed as 2-space nesting —
    confirm it matches what the loader's required-key checks expect.
    """
    return f"""
agent:
  vad:
    type: silero
    model_path: data/vad/silero_vad.onnx
    threshold: 0.63
    min_speech_duration_ms: 100
    eou_threshold_ms: 800
  llm:
    provider: openai_compatible
    model: {llm_model}
    temperature: 0.2
    api_key: {llm_key}
    api_url: https://example-llm.invalid/v1
  tts:
    provider: openai_compatible
    api_key: test-tts-key
    api_url: https://example-tts.invalid/v1/audio/speech
    model: FunAudioLLM/CosyVoice2-0.5B
    voice: anna
    speed: 1.0
  asr:
    provider: openai_compatible
    api_key: test-asr-key
    api_url: https://example-asr.invalid/v1/audio/transcriptions
    model: FunAudioLLM/SenseVoiceSmall
    interim_interval_ms: 500
    min_audio_ms: 300
    start_min_speech_ms: 160
    pre_speech_ms: 240
    final_tail_ms: 120
  duplex:
    enabled: true
  system_prompt: You are a strict test assistant.
  barge_in:
    min_duration_ms: 200
    silence_tolerance_ms: 60
""".strip()
def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path):
    """`--agent-profile <name>` resolves to <cwd>/config/agents/<name>.yaml."""
    monkeypatch.chdir(tmp_path)
    profile_path = tmp_path / "config" / "agents" / "support.yaml"
    _write_yaml(profile_path, _full_agent_yaml(llm_model="gpt-4.1-mini"))

    loaded = load_settings(argv=["--agent-profile", "support"])

    # Values from the YAML must land on the settings object.
    assert loaded.llm_model == "gpt-4.1-mini"
    assert loaded.llm_temperature == 0.2
    assert loaded.vad_threshold == 0.63
    # Provenance metadata must record how the config was selected.
    assert loaded.agent_config_source == "cli_profile"
    assert loaded.agent_config_path == str(profile_path.resolve())
def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path):
    """An explicit --agent-config path wins over AGENT_CONFIG_PATH from the env."""
    monkeypatch.chdir(tmp_path)
    agents_dir = tmp_path / "config" / "agents"
    env_yaml = agents_dir / "env.yaml"
    cli_yaml = agents_dir / "cli.yaml"
    _write_yaml(env_yaml, _full_agent_yaml(llm_model="env-model"))
    _write_yaml(cli_yaml, _full_agent_yaml(llm_model="cli-model"))
    monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_yaml))

    loaded = load_settings(argv=["--agent-config", str(cli_yaml)])

    # CLI path must shadow the env-provided file entirely.
    assert loaded.llm_model == "cli-model"
    assert loaded.agent_config_source == "cli_path"
    assert loaded.agent_config_path == str(cli_yaml.resolve())
def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path):
    """With no CLI args and no env overrides, config/agents/default.yaml is used."""
    monkeypatch.chdir(tmp_path)
    default_yaml = tmp_path / "config" / "agents" / "default.yaml"
    _write_yaml(default_yaml, _full_agent_yaml(llm_model="from-default"))
    # Make sure neither env selector leaks in from the outer environment.
    for var in ("AGENT_CONFIG_PATH", "AGENT_PROFILE"):
        monkeypatch.delenv(var, raising=False)

    loaded = load_settings(argv=[])

    assert loaded.llm_model == "from-default"
    assert loaded.agent_config_source == "default"
    assert loaded.agent_config_path == str(default_yaml.resolve())
def test_missing_required_agent_settings_fail(monkeypatch, tmp_path):
    """A YAML that provides only agent.llm.model must be rejected as incomplete."""
    monkeypatch.chdir(tmp_path)
    file_path = tmp_path / "missing-required.yaml"
    # Only one leaf key is present; every other required section/key is absent.
    _write_yaml(
        file_path,
        """
agent:
  llm:
    model: gpt-4o-mini
""".strip(),
    )
    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(file_path)])
def test_blank_required_provider_key_fails(monkeypatch, tmp_path):
    """An empty api_key value counts as missing and must fail validation."""
    monkeypatch.chdir(tmp_path)
    blank_key_yaml = tmp_path / "blank-key.yaml"
    _write_yaml(blank_key_yaml, _full_agent_yaml(llm_key=""))

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(blank_key_yaml)])
def test_missing_tts_api_url_fails(monkeypatch, tmp_path):
    """Dropping agent.tts.api_url from an otherwise complete YAML must fail."""
    monkeypatch.chdir(tmp_path)
    file_path = tmp_path / "missing-tts-url.yaml"
    # Remove the exact tts api_url line from the full template.
    # NOTE(review): leading spaces reconstructed to match _full_agent_yaml's
    # 2-space nesting — confirm the replace target matches the template bytes.
    _write_yaml(
        file_path,
        _full_agent_yaml().replace(
            "    api_url: https://example-tts.invalid/v1/audio/speech\n",
            "",
        ),
    )
    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(file_path)])
def test_missing_asr_api_url_fails(monkeypatch, tmp_path):
    """Dropping agent.asr.api_url from an otherwise complete YAML must fail."""
    monkeypatch.chdir(tmp_path)
    file_path = tmp_path / "missing-asr-url.yaml"
    # Remove the exact asr api_url line from the full template.
    # NOTE(review): leading spaces reconstructed to match _full_agent_yaml's
    # 2-space nesting — confirm the replace target matches the template bytes.
    _write_yaml(
        file_path,
        _full_agent_yaml().replace(
            "    api_url: https://example-asr.invalid/v1/audio/transcriptions\n",
            "",
        ),
    )
    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(file_path)])
def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path):
    """Undeclared keys under 'agent' must be rejected (strict schema)."""
    monkeypatch.chdir(tmp_path)
    file_path = tmp_path / "bad-agent.yaml"
    # Append an unknown key at section level under 'agent'.
    # NOTE(review): indent reconstructed as 2 spaces — confirm against template.
    _write_yaml(file_path, _full_agent_yaml() + "\n  unknown_option: true")
    with pytest.raises(ValueError, match="Unknown agent config keys"):
        load_settings(argv=["--agent-config", str(file_path)])
def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path):
    """The removed top-level 'siliconflow' section must produce a dedicated error."""
    monkeypatch.chdir(tmp_path)
    file_path = tmp_path / "legacy-siliconflow.yaml"
    # Legacy config shape from before provider-specific sections were dropped.
    _write_yaml(
        file_path,
        """
agent:
  siliconflow:
    api_key: x
""".strip(),
    )
    with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"):
        load_settings(argv=["--agent-config", str(file_path)])
def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path):
    """A ${VAR} reference to an unset environment variable must raise."""
    monkeypatch.chdir(tmp_path)
    bad_ref_yaml = tmp_path / "bad-ref.yaml"
    _write_yaml(bad_ref_yaml, _full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"))

    with pytest.raises(ValueError, match="Missing environment variable"):
        load_settings(argv=["--agent-config", str(bad_ref_yaml)])