Organize config
This commit is contained in:
64
.env.example
64
.env.example
@@ -23,57 +23,21 @@ CHUNK_SIZE_MS=20
|
|||||||
DEFAULT_CODEC=pcm
|
DEFAULT_CODEC=pcm
|
||||||
MAX_AUDIO_BUFFER_SECONDS=30
|
MAX_AUDIO_BUFFER_SECONDS=30
|
||||||
|
|
||||||
# VAD / EOU
|
# Agent profile selection (optional fallback when CLI args are not used)
|
||||||
VAD_TYPE=silero
|
# Prefer CLI:
|
||||||
VAD_MODEL_PATH=data/vad/silero_vad.onnx
|
# python -m app.main --agent-config config/agents/default.yaml
|
||||||
# Higher = stricter speech detection (fewer false positives, more misses).
|
# python -m app.main --agent-profile default
|
||||||
VAD_THRESHOLD=0.5
|
# AGENT_CONFIG_PATH=config/agents/default.yaml
|
||||||
# Require this much continuous speech before utterance can be valid.
|
# AGENT_PROFILE=default
|
||||||
VAD_MIN_SPEECH_DURATION_MS=100
|
AGENT_CONFIG_DIR=config/agents
|
||||||
# Silence duration required to finalize one user turn.
|
|
||||||
VAD_EOU_THRESHOLD_MS=800
|
|
||||||
|
|
||||||
# LLM
|
# Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY}
|
||||||
OPENAI_API_KEY=your_openai_api_key_here
|
# LLM_API_KEY=your_llm_api_key_here
|
||||||
# Optional for OpenAI-compatible providers.
|
# LLM_API_URL=https://api.openai.com/v1
|
||||||
# OPENAI_API_URL=https://api.openai.com/v1
|
# TTS_API_KEY=your_tts_api_key_here
|
||||||
LLM_MODEL=gpt-4o-mini
|
# TTS_API_URL=https://api.example.com/v1/audio/speech
|
||||||
LLM_TEMPERATURE=0.7
|
# ASR_API_KEY=your_asr_api_key_here
|
||||||
|
# ASR_API_URL=https://api.example.com/v1/audio/transcriptions
|
||||||
# TTS
|
|
||||||
# edge: no API key needed
|
|
||||||
# openai_compatible: compatible with SiliconFlow-style endpoints
|
|
||||||
TTS_PROVIDER=openai_compatible
|
|
||||||
TTS_VOICE=anna
|
|
||||||
TTS_SPEED=1.0
|
|
||||||
|
|
||||||
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
|
|
||||||
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
|
|
||||||
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
|
|
||||||
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
|
|
||||||
|
|
||||||
# ASR
|
|
||||||
ASR_PROVIDER=openai_compatible
|
|
||||||
# Interim cadence and minimum audio before interim decode.
|
|
||||||
ASR_INTERIM_INTERVAL_MS=500
|
|
||||||
ASR_MIN_AUDIO_MS=300
|
|
||||||
# ASR start gate: ignore micro-noise, then commit to one turn once started.
|
|
||||||
ASR_START_MIN_SPEECH_MS=160
|
|
||||||
# Pre-roll protects beginning phonemes.
|
|
||||||
ASR_PRE_SPEECH_MS=240
|
|
||||||
# Tail silence protects ending phonemes.
|
|
||||||
ASR_FINAL_TAIL_MS=120
|
|
||||||
|
|
||||||
# Duplex behavior
|
|
||||||
DUPLEX_ENABLED=true
|
|
||||||
# DUPLEX_GREETING=Hello! How can I help you today?
|
|
||||||
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
|
||||||
|
|
||||||
# Barge-in (user interrupting assistant)
|
|
||||||
# Min user speech duration needed to interrupt assistant audio.
|
|
||||||
BARGE_IN_MIN_DURATION_MS=200
|
|
||||||
# Allowed silence during potential barge-in (ms) before reset.
|
|
||||||
BARGE_IN_SILENCE_TOLERANCE_MS=60
|
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
|
|||||||
26
README.md
26
README.md
@@ -14,6 +14,30 @@ It is currently in an early, experimental stage.
|
|||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
使用 agent profile(推荐)
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m app.main --agent-profile default
|
||||||
|
```
|
||||||
|
|
||||||
|
使用指定 YAML
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m app.main --agent-config config/agents/default.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Agent 配置路径优先级
|
||||||
|
1. `--agent-config`
|
||||||
|
2. `--agent-profile`(映射到 `config/agents/<profile>.yaml`)
|
||||||
|
3. `AGENT_CONFIG_PATH`
|
||||||
|
4. `AGENT_PROFILE`
|
||||||
|
5. `config/agents/default.yaml`(若存在)
|
||||||
|
|
||||||
|
说明
|
||||||
|
- Agent 相关配置是严格模式:YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。
|
||||||
|
- 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。
|
||||||
|
- `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。
|
||||||
|
|
||||||
测试
|
测试
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -28,4 +52,4 @@ python mic_client.py
|
|||||||
|
|
||||||
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
|
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
|
||||||
|
|
||||||
See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`.
|
See `docs/ws_v1_schema.md`.
|
||||||
|
|||||||
382
app/config.py
382
app/config.py
@@ -1,9 +1,354 @@
|
|||||||
"""Configuration management using Pydantic settings."""
|
"""Configuration management using Pydantic settings and agent YAML profiles."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
import json
|
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError: # pragma: no cover - validated when agent YAML is used
|
||||||
|
yaml = None
|
||||||
|
|
||||||
|
|
||||||
|
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
|
||||||
|
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
|
||||||
|
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
|
||||||
|
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
|
||||||
|
"vad": {
|
||||||
|
"type": "vad_type",
|
||||||
|
"model_path": "vad_model_path",
|
||||||
|
"threshold": "vad_threshold",
|
||||||
|
"min_speech_duration_ms": "vad_min_speech_duration_ms",
|
||||||
|
"eou_threshold_ms": "vad_eou_threshold_ms",
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"provider": "llm_provider",
|
||||||
|
"model": "llm_model",
|
||||||
|
"temperature": "llm_temperature",
|
||||||
|
"api_key": "llm_api_key",
|
||||||
|
"api_url": "llm_api_url",
|
||||||
|
},
|
||||||
|
"tts": {
|
||||||
|
"provider": "tts_provider",
|
||||||
|
"api_key": "tts_api_key",
|
||||||
|
"api_url": "tts_api_url",
|
||||||
|
"model": "tts_model",
|
||||||
|
"voice": "tts_voice",
|
||||||
|
"speed": "tts_speed",
|
||||||
|
},
|
||||||
|
"asr": {
|
||||||
|
"provider": "asr_provider",
|
||||||
|
"api_key": "asr_api_key",
|
||||||
|
"api_url": "asr_api_url",
|
||||||
|
"model": "asr_model",
|
||||||
|
"interim_interval_ms": "asr_interim_interval_ms",
|
||||||
|
"min_audio_ms": "asr_min_audio_ms",
|
||||||
|
"start_min_speech_ms": "asr_start_min_speech_ms",
|
||||||
|
"pre_speech_ms": "asr_pre_speech_ms",
|
||||||
|
"final_tail_ms": "asr_final_tail_ms",
|
||||||
|
},
|
||||||
|
"duplex": {
|
||||||
|
"enabled": "duplex_enabled",
|
||||||
|
"greeting": "duplex_greeting",
|
||||||
|
"system_prompt": "duplex_system_prompt",
|
||||||
|
},
|
||||||
|
"barge_in": {
|
||||||
|
"min_duration_ms": "barge_in_min_duration_ms",
|
||||||
|
"silence_tolerance_ms": "barge_in_silence_tolerance_ms",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_AGENT_SETTING_KEYS = {
|
||||||
|
"vad_type",
|
||||||
|
"vad_model_path",
|
||||||
|
"vad_threshold",
|
||||||
|
"vad_min_speech_duration_ms",
|
||||||
|
"vad_eou_threshold_ms",
|
||||||
|
"llm_provider",
|
||||||
|
"llm_api_key",
|
||||||
|
"llm_api_url",
|
||||||
|
"llm_model",
|
||||||
|
"llm_temperature",
|
||||||
|
"tts_provider",
|
||||||
|
"tts_api_key",
|
||||||
|
"tts_api_url",
|
||||||
|
"tts_model",
|
||||||
|
"tts_voice",
|
||||||
|
"tts_speed",
|
||||||
|
"asr_provider",
|
||||||
|
"asr_api_key",
|
||||||
|
"asr_api_url",
|
||||||
|
"asr_model",
|
||||||
|
"asr_interim_interval_ms",
|
||||||
|
"asr_min_audio_ms",
|
||||||
|
"asr_start_min_speech_ms",
|
||||||
|
"asr_pre_speech_ms",
|
||||||
|
"asr_final_tail_ms",
|
||||||
|
"duplex_enabled",
|
||||||
|
"duplex_greeting",
|
||||||
|
"duplex_system_prompt",
|
||||||
|
"barge_in_min_duration_ms",
|
||||||
|
"barge_in_silence_tolerance_ms",
|
||||||
|
}
|
||||||
|
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
|
||||||
|
"vad_type",
|
||||||
|
"vad_model_path",
|
||||||
|
"vad_threshold",
|
||||||
|
"vad_min_speech_duration_ms",
|
||||||
|
"vad_eou_threshold_ms",
|
||||||
|
"llm_provider",
|
||||||
|
"llm_model",
|
||||||
|
"llm_temperature",
|
||||||
|
"tts_provider",
|
||||||
|
"tts_voice",
|
||||||
|
"tts_speed",
|
||||||
|
"asr_provider",
|
||||||
|
"asr_interim_interval_ms",
|
||||||
|
"asr_min_audio_ms",
|
||||||
|
"asr_start_min_speech_ms",
|
||||||
|
"asr_pre_speech_ms",
|
||||||
|
"asr_final_tail_ms",
|
||||||
|
"duplex_enabled",
|
||||||
|
"duplex_system_prompt",
|
||||||
|
"barge_in_min_duration_ms",
|
||||||
|
"barge_in_silence_tolerance_ms",
|
||||||
|
}
|
||||||
|
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
|
||||||
|
return str(overrides.get(key) or default).strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_blank(value: Any) -> bool:
|
||||||
|
return value is None or (isinstance(value, str) and not value.strip())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentConfigSelection:
|
||||||
|
"""Resolved agent config location and how it was selected."""
|
||||||
|
|
||||||
|
path: Optional[Path]
|
||||||
|
source: str
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
"""Parse only agent-related CLI flags from argv."""
|
||||||
|
config_path: Optional[str] = None
|
||||||
|
profile: Optional[str] = None
|
||||||
|
i = 0
|
||||||
|
while i < len(argv):
|
||||||
|
arg = argv[i]
|
||||||
|
if arg.startswith("--agent-config="):
|
||||||
|
config_path = arg.split("=", 1)[1].strip() or None
|
||||||
|
elif arg == "--agent-config" and i + 1 < len(argv):
|
||||||
|
config_path = argv[i + 1].strip() or None
|
||||||
|
i += 1
|
||||||
|
elif arg.startswith("--agent-profile="):
|
||||||
|
profile = arg.split("=", 1)[1].strip() or None
|
||||||
|
elif arg == "--agent-profile" and i + 1 < len(argv):
|
||||||
|
profile = argv[i + 1].strip() or None
|
||||||
|
i += 1
|
||||||
|
i += 1
|
||||||
|
return config_path, profile
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_config_dir() -> Path:
|
||||||
|
base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR))
|
||||||
|
if not base_dir.is_absolute():
|
||||||
|
base_dir = Path.cwd() / base_dir
|
||||||
|
return base_dir.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_agent_selection(
|
||||||
|
agent_config_path: Optional[str] = None,
|
||||||
|
agent_profile: Optional[str] = None,
|
||||||
|
argv: Optional[List[str]] = None,
|
||||||
|
) -> AgentConfigSelection:
|
||||||
|
cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:]))
|
||||||
|
path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH")
|
||||||
|
profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE")
|
||||||
|
source = "none"
|
||||||
|
candidate: Optional[Path] = None
|
||||||
|
|
||||||
|
if path_value:
|
||||||
|
source = "cli_path" if (agent_config_path or cli_path) else "env_path"
|
||||||
|
candidate = Path(path_value)
|
||||||
|
elif profile_value:
|
||||||
|
source = "cli_profile" if (agent_profile or cli_profile) else "env_profile"
|
||||||
|
candidate = _agent_config_dir() / f"{profile_value}.yaml"
|
||||||
|
else:
|
||||||
|
fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
|
||||||
|
if fallback.exists():
|
||||||
|
source = "default"
|
||||||
|
candidate = fallback
|
||||||
|
|
||||||
|
if candidate is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Agent YAML config is required. Provide --agent-config/--agent-profile "
|
||||||
|
"or create config/agents/default.yaml."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidate.is_absolute():
|
||||||
|
candidate = (Path.cwd() / candidate).resolve()
|
||||||
|
else:
|
||||||
|
candidate = candidate.resolve()
|
||||||
|
|
||||||
|
if not candidate.exists():
|
||||||
|
raise ValueError(f"Agent config file not found ({source}): {candidate}")
|
||||||
|
if not candidate.is_file():
|
||||||
|
raise ValueError(f"Agent config path is not a file: {candidate}")
|
||||||
|
return AgentConfigSelection(path=candidate, source=source)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_env_refs(value: Any) -> Any:
|
||||||
|
"""Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively."""
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _resolve_env_refs(v) for k, v in value.items()}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_resolve_env_refs(item) for item in value]
|
||||||
|
if not isinstance(value, str) or "${" not in value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _replace(match: re.Match[str]) -> str:
|
||||||
|
env_key = match.group(1)
|
||||||
|
default_value = match.group(2)
|
||||||
|
env_value = os.getenv(env_key)
|
||||||
|
if env_value is None:
|
||||||
|
if default_value is None:
|
||||||
|
raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}")
|
||||||
|
return default_value
|
||||||
|
return env_value
|
||||||
|
|
||||||
|
return _ENV_REF_PATTERN.sub(_replace, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Normalize YAML into flat Settings fields."""
|
||||||
|
normalized: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
for key, value in raw.items():
|
||||||
|
if key == "siliconflow":
|
||||||
|
raise ValueError(
|
||||||
|
"Section 'siliconflow' is no longer supported. "
|
||||||
|
"Move provider-specific fields into agent.llm / agent.asr / agent.tts."
|
||||||
|
)
|
||||||
|
section_map = _AGENT_SECTION_KEY_MAP.get(key)
|
||||||
|
if section_map is None:
|
||||||
|
normalized[key] = value
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(f"Agent config section '{key}' must be a mapping")
|
||||||
|
|
||||||
|
for nested_key, nested_value in value.items():
|
||||||
|
mapped_key = section_map.get(nested_key)
|
||||||
|
if mapped_key is None:
|
||||||
|
raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'")
|
||||||
|
normalized[mapped_key] = nested_value
|
||||||
|
|
||||||
|
unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS)
|
||||||
|
if unknown_keys:
|
||||||
|
raise ValueError(
|
||||||
|
"Unknown agent config keys in YAML: "
|
||||||
|
+ ", ".join(unknown_keys)
|
||||||
|
)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
|
||||||
|
missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides))
|
||||||
|
string_required = {
|
||||||
|
"vad_type",
|
||||||
|
"vad_model_path",
|
||||||
|
"llm_provider",
|
||||||
|
"llm_model",
|
||||||
|
"tts_provider",
|
||||||
|
"tts_voice",
|
||||||
|
"asr_provider",
|
||||||
|
"duplex_system_prompt",
|
||||||
|
}
|
||||||
|
for key in string_required:
|
||||||
|
if key in overrides and _is_blank(overrides.get(key)):
|
||||||
|
missing.add(key)
|
||||||
|
|
||||||
|
llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
|
||||||
|
if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai":
|
||||||
|
if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")):
|
||||||
|
missing.add("llm_api_key")
|
||||||
|
|
||||||
|
tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
|
||||||
|
if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
|
||||||
|
if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")):
|
||||||
|
missing.add("tts_api_key")
|
||||||
|
if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")):
|
||||||
|
missing.add("tts_api_url")
|
||||||
|
if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")):
|
||||||
|
missing.add("tts_model")
|
||||||
|
|
||||||
|
asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
|
||||||
|
if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
|
||||||
|
if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")):
|
||||||
|
missing.add("asr_api_key")
|
||||||
|
if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")):
|
||||||
|
missing.add("asr_api_url")
|
||||||
|
if "asr_model" not in overrides or _is_blank(overrides.get("asr_model")):
|
||||||
|
missing.add("asr_model")
|
||||||
|
|
||||||
|
return sorted(missing)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
|
||||||
|
if yaml is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
with selection.path.open("r", encoding="utf-8") as file:
|
||||||
|
raw = yaml.safe_load(file) or {}
|
||||||
|
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")
|
||||||
|
|
||||||
|
if "agent" in raw:
|
||||||
|
agent_value = raw["agent"]
|
||||||
|
if not isinstance(agent_value, dict):
|
||||||
|
raise ValueError("The 'agent' key in YAML must be a mapping")
|
||||||
|
raw = agent_value
|
||||||
|
|
||||||
|
resolved = _resolve_env_refs(raw)
|
||||||
|
overrides = _normalize_agent_overrides(resolved)
|
||||||
|
missing_required = _missing_required_keys(overrides)
|
||||||
|
if missing_required:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing required agent settings in YAML ({selection.path}): "
|
||||||
|
+ ", ".join(missing_required)
|
||||||
|
)
|
||||||
|
|
||||||
|
overrides["agent_config_path"] = str(selection.path)
|
||||||
|
overrides["agent_config_source"] = selection.source
|
||||||
|
return overrides
|
||||||
|
|
||||||
|
|
||||||
|
def load_settings(
|
||||||
|
agent_config_path: Optional[str] = None,
|
||||||
|
agent_profile: Optional[str] = None,
|
||||||
|
argv: Optional[List[str]] = None,
|
||||||
|
) -> "Settings":
|
||||||
|
"""Load settings from .env and optional agent YAML."""
|
||||||
|
selection = _resolve_agent_selection(
|
||||||
|
agent_config_path=agent_config_path,
|
||||||
|
agent_profile=agent_profile,
|
||||||
|
argv=argv,
|
||||||
|
)
|
||||||
|
agent_overrides = _load_agent_overrides(selection)
|
||||||
|
return Settings(**agent_overrides)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -37,30 +382,35 @@ class Settings(BaseSettings):
|
|||||||
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
|
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
|
||||||
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
|
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
|
||||||
|
|
||||||
# OpenAI / LLM Configuration
|
# LLM Configuration
|
||||||
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
|
llm_provider: str = Field(
|
||||||
openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
|
default="openai",
|
||||||
|
description="LLM provider (openai, openai_compatible, siliconflow)"
|
||||||
|
)
|
||||||
|
llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
|
||||||
|
llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
|
||||||
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
|
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
|
||||||
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
|
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
|
||||||
|
|
||||||
# TTS Configuration
|
# TTS Configuration
|
||||||
tts_provider: str = Field(
|
tts_provider: str = Field(
|
||||||
default="openai_compatible",
|
default="openai_compatible",
|
||||||
description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
|
description="TTS provider (edge, openai_compatible, siliconflow)"
|
||||||
)
|
)
|
||||||
|
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
|
||||||
|
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
|
||||||
|
tts_model: Optional[str] = Field(default=None, description="TTS model name")
|
||||||
tts_voice: str = Field(default="anna", description="TTS voice name")
|
tts_voice: str = Field(default="anna", description="TTS voice name")
|
||||||
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
|
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
|
||||||
|
|
||||||
# SiliconFlow Configuration
|
|
||||||
siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
|
|
||||||
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
|
|
||||||
|
|
||||||
# ASR Configuration
|
# ASR Configuration
|
||||||
asr_provider: str = Field(
|
asr_provider: str = Field(
|
||||||
default="openai_compatible",
|
default="openai_compatible",
|
||||||
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
|
description="ASR provider (openai_compatible, buffered, siliconflow)"
|
||||||
)
|
)
|
||||||
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
|
asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
|
||||||
|
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
|
||||||
|
asr_model: Optional[str] = Field(default=None, description="ASR model name")
|
||||||
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
|
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
|
||||||
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
|
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
|
||||||
asr_start_min_speech_ms: int = Field(
|
asr_start_min_speech_ms: int = Field(
|
||||||
@@ -122,6 +472,10 @@ class Settings(BaseSettings):
|
|||||||
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
|
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
|
||||||
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
|
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
|
||||||
|
|
||||||
|
# Agent YAML metadata
|
||||||
|
agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
|
||||||
|
agent_config_source: str = Field(default="none", description="How the agent YAML was selected")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_size_bytes(self) -> int:
|
def chunk_size_bytes(self) -> int:
|
||||||
"""Calculate chunk size in bytes based on sample rate and duration."""
|
"""Calculate chunk size in bytes based on sample rate and duration."""
|
||||||
@@ -146,7 +500,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
# Global settings instance
|
# Global settings instance
|
||||||
settings = Settings()
|
settings = load_settings()
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
|
|||||||
@@ -357,6 +357,12 @@ async def startup_event():
|
|||||||
logger.info(f"Server: {settings.host}:{settings.port}")
|
logger.info(f"Server: {settings.host}:{settings.port}")
|
||||||
logger.info(f"Sample rate: {settings.sample_rate} Hz")
|
logger.info(f"Sample rate: {settings.sample_rate} Hz")
|
||||||
logger.info(f"VAD model: {settings.vad_model_path}")
|
logger.info(f"VAD model: {settings.vad_model_path}")
|
||||||
|
if settings.agent_config_path:
|
||||||
|
logger.info(
|
||||||
|
f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Agent config: none (using .env/default agent values)")
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
|
|||||||
50
config/agents/example.yaml
Normal file
50
config/agents/example.yaml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Agent behavior configuration (safe to edit per profile)
|
||||||
|
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
|
||||||
|
# Infra/server/network settings should stay in .env.
|
||||||
|
|
||||||
|
agent:
|
||||||
|
vad:
|
||||||
|
type: silero
|
||||||
|
model_path: data/vad/silero_vad.onnx
|
||||||
|
threshold: 0.5
|
||||||
|
min_speech_duration_ms: 100
|
||||||
|
eou_threshold_ms: 800
|
||||||
|
|
||||||
|
llm:
|
||||||
|
# provider: openai | openai_compatible | siliconflow
|
||||||
|
provider: openai_compatible
|
||||||
|
model: deepseek-v3
|
||||||
|
temperature: 0.7
|
||||||
|
# Required: no fallback. You can still reference env explicitly.
|
||||||
|
api_key: your_llm_api_key
|
||||||
|
# Optional for OpenAI-compatible endpoints:
|
||||||
|
api_url: https://api.qnaigc.com/v1
|
||||||
|
|
||||||
|
tts:
|
||||||
|
# provider: edge | openai_compatible | siliconflow
|
||||||
|
provider: openai_compatible
|
||||||
|
api_key: your_tts_api_key
|
||||||
|
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||||
|
model: FunAudioLLM/CosyVoice2-0.5B
|
||||||
|
voice: anna
|
||||||
|
speed: 1.0
|
||||||
|
|
||||||
|
asr:
|
||||||
|
# provider: buffered | openai_compatible | siliconflow
|
||||||
|
provider: openai_compatible
|
||||||
|
api_key: you_asr_api_key
|
||||||
|
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
|
||||||
|
model: FunAudioLLM/SenseVoiceSmall
|
||||||
|
interim_interval_ms: 500
|
||||||
|
min_audio_ms: 300
|
||||||
|
start_min_speech_ms: 160
|
||||||
|
pre_speech_ms: 240
|
||||||
|
final_tail_ms: 120
|
||||||
|
|
||||||
|
duplex:
|
||||||
|
enabled: true
|
||||||
|
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||||
|
|
||||||
|
barge_in:
|
||||||
|
min_duration_ms: 200
|
||||||
|
silence_tolerance_ms: 60
|
||||||
@@ -310,7 +310,12 @@ class DuplexPipeline:
|
|||||||
|
|
||||||
def resolved_runtime_config(self) -> Dict[str, Any]:
|
def resolved_runtime_config(self) -> Dict[str, Any]:
|
||||||
"""Return current effective runtime configuration without secrets."""
|
"""Return current effective runtime configuration without secrets."""
|
||||||
llm_provider = str(self._runtime_llm.get("provider") or "openai").lower()
|
llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
|
||||||
|
llm_base_url = (
|
||||||
|
self._runtime_llm.get("baseUrl")
|
||||||
|
or settings.llm_api_url
|
||||||
|
or self._default_llm_base_url(llm_provider)
|
||||||
|
)
|
||||||
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||||
asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||||
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
|
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
|
||||||
@@ -323,18 +328,18 @@ class DuplexPipeline:
|
|||||||
"llm": {
|
"llm": {
|
||||||
"provider": llm_provider,
|
"provider": llm_provider,
|
||||||
"model": str(self._runtime_llm.get("model") or settings.llm_model),
|
"model": str(self._runtime_llm.get("model") or settings.llm_model),
|
||||||
"baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url,
|
"baseUrl": llm_base_url,
|
||||||
},
|
},
|
||||||
"asr": {
|
"asr": {
|
||||||
"provider": asr_provider,
|
"provider": asr_provider,
|
||||||
"model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model),
|
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
|
||||||
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
|
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
|
||||||
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
|
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
|
||||||
},
|
},
|
||||||
"tts": {
|
"tts": {
|
||||||
"enabled": self._tts_output_enabled(),
|
"enabled": self._tts_output_enabled(),
|
||||||
"provider": tts_provider,
|
"provider": tts_provider,
|
||||||
"model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model),
|
"model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
|
||||||
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
|
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
|
||||||
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
|
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
|
||||||
},
|
},
|
||||||
@@ -452,6 +457,18 @@ class DuplexPipeline:
|
|||||||
normalized = str(provider or "").strip().lower()
|
normalized = str(provider or "").strip().lower()
|
||||||
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
|
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_llm_provider_supported(provider: Any) -> bool:
|
||||||
|
normalized = str(provider or "").strip().lower()
|
||||||
|
return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_llm_base_url(provider: Any) -> Optional[str]:
|
||||||
|
normalized = str(provider or "").strip().lower()
|
||||||
|
if normalized == "siliconflow":
|
||||||
|
return "https://api.siliconflow.cn/v1"
|
||||||
|
return None
|
||||||
|
|
||||||
def _tts_output_enabled(self) -> bool:
|
def _tts_output_enabled(self) -> bool:
|
||||||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||||||
if enabled is not None:
|
if enabled is not None:
|
||||||
@@ -527,12 +544,16 @@ class DuplexPipeline:
|
|||||||
try:
|
try:
|
||||||
# Connect LLM service
|
# Connect LLM service
|
||||||
if not self.llm_service:
|
if not self.llm_service:
|
||||||
llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
|
llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
|
||||||
llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
|
llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key
|
||||||
|
llm_base_url = (
|
||||||
|
self._runtime_llm.get("baseUrl")
|
||||||
|
or settings.llm_api_url
|
||||||
|
or self._default_llm_base_url(llm_provider)
|
||||||
|
)
|
||||||
llm_model = self._runtime_llm.get("model") or settings.llm_model
|
llm_model = self._runtime_llm.get("model") or settings.llm_model
|
||||||
llm_provider = (self._runtime_llm.get("provider") or "openai").lower()
|
|
||||||
|
|
||||||
if llm_provider == "openai" and llm_api_key:
|
if self._is_llm_provider_supported(llm_provider) and llm_api_key:
|
||||||
self.llm_service = OpenAILLMService(
|
self.llm_service = OpenAILLMService(
|
||||||
api_key=llm_api_key,
|
api_key=llm_api_key,
|
||||||
base_url=llm_base_url,
|
base_url=llm_base_url,
|
||||||
@@ -540,7 +561,7 @@ class DuplexPipeline:
|
|||||||
knowledge_config=self._resolved_knowledge_config(),
|
knowledge_config=self._resolved_knowledge_config(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("No OpenAI API key - using mock LLM")
|
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
|
||||||
self.llm_service = MockLLMService()
|
self.llm_service = MockLLMService()
|
||||||
|
|
||||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||||
@@ -556,20 +577,22 @@ class DuplexPipeline:
|
|||||||
if tts_output_enabled:
|
if tts_output_enabled:
|
||||||
if not self.tts_service:
|
if not self.tts_service:
|
||||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||||
tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
|
tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key
|
||||||
|
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
|
||||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
tts_model = self._runtime_tts.get("model") or settings.tts_model
|
||||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||||
|
|
||||||
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
||||||
self.tts_service = OpenAICompatibleTTSService(
|
self.tts_service = OpenAICompatibleTTSService(
|
||||||
api_key=tts_api_key,
|
api_key=tts_api_key,
|
||||||
|
api_url=tts_api_url,
|
||||||
voice=tts_voice,
|
voice=tts_voice,
|
||||||
model=tts_model,
|
model=tts_model or "FunAudioLLM/CosyVoice2-0.5B",
|
||||||
sample_rate=settings.sample_rate,
|
sample_rate=settings.sample_rate,
|
||||||
speed=tts_speed
|
speed=tts_speed
|
||||||
)
|
)
|
||||||
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
|
logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})")
|
||||||
else:
|
else:
|
||||||
self.tts_service = EdgeTTSService(
|
self.tts_service = EdgeTTSService(
|
||||||
voice=tts_voice,
|
voice=tts_voice,
|
||||||
@@ -592,21 +615,23 @@ class DuplexPipeline:
|
|||||||
# Connect ASR service
|
# Connect ASR service
|
||||||
if not self.asr_service:
|
if not self.asr_service:
|
||||||
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||||
asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
|
asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key
|
||||||
asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
|
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
|
||||||
|
asr_model = self._runtime_asr.get("model") or settings.asr_model
|
||||||
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||||
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||||
|
|
||||||
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
|
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
|
||||||
self.asr_service = OpenAICompatibleASRService(
|
self.asr_service = OpenAICompatibleASRService(
|
||||||
api_key=asr_api_key,
|
api_key=asr_api_key,
|
||||||
model=asr_model,
|
api_url=asr_api_url,
|
||||||
|
model=asr_model or "FunAudioLLM/SenseVoiceSmall",
|
||||||
sample_rate=settings.sample_rate,
|
sample_rate=settings.sample_rate,
|
||||||
interim_interval_ms=asr_interim_interval,
|
interim_interval_ms=asr_interim_interval,
|
||||||
min_audio_for_interim_ms=asr_min_audio_ms,
|
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||||
on_transcript=self._on_transcript_callback
|
on_transcript=self._on_transcript_callback
|
||||||
)
|
)
|
||||||
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
|
logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})")
|
||||||
else:
|
else:
|
||||||
self.asr_service = BufferedASRService(
|
self.asr_service = BufferedASRService(
|
||||||
sample_rate=settings.sample_rate
|
sample_rate=settings.sample_rate
|
||||||
|
|||||||
520
docs/ws_v1_schema_zh.md
Normal file
520
docs/ws_v1_schema_zh.md
Normal file
@@ -0,0 +1,520 @@
|
|||||||
|
# WS v1 协议完整说明(中文)
|
||||||
|
|
||||||
|
本文档描述 `/ws` 端点的 WebSocket v1 协议,覆盖:
|
||||||
|
- 客户端输入(JSON 文本消息 + 二进制音频);
|
||||||
|
- 服务端输出(JSON 事件 + 二进制音频);
|
||||||
|
- 每个参数的类型、约束、含义与使用方式;
|
||||||
|
- 握手顺序、状态机、错误语义与实现细节。
|
||||||
|
|
||||||
|
实现对照来源:
|
||||||
|
- `models/ws_v1.py`
|
||||||
|
- `core/session.py`
|
||||||
|
- `core/duplex_pipeline.py`
|
||||||
|
- `app/main.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 传输与基础规则
|
||||||
|
|
||||||
|
- 连接地址:`ws://<host>/ws`
|
||||||
|
- 单连接双通道承载:
|
||||||
|
- 文本帧:JSON 控制消息(严格校验 schema)
|
||||||
|
- 二进制帧:原始 PCM 音频
|
||||||
|
- JSON 校验策略:
|
||||||
|
- 所有已定义客户端消息都 `extra="forbid"`,即不允许未声明字段;
|
||||||
|
- `hello.version` 固定必须是 `"v1"`;
|
||||||
|
- 缺失 `type` 或未知 `type` 会返回协议错误。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 状态机与消息顺序
|
||||||
|
|
||||||
|
### 2.1 服务端状态
|
||||||
|
|
||||||
|
- `WAIT_HELLO`:等待 `hello`
|
||||||
|
- `WAIT_START`:已通过握手,等待 `session.start`
|
||||||
|
- `ACTIVE`:会话运行中,可收发文本/音频
|
||||||
|
- `STOPPED`:会话结束
|
||||||
|
|
||||||
|
### 2.2 正确顺序
|
||||||
|
|
||||||
|
1. 客户端发送 `hello`
|
||||||
|
2. 服务端返回 `hello.ack`
|
||||||
|
3. 客户端发送 `session.start`
|
||||||
|
4. 服务端返回 `session.started`
|
||||||
|
5. 客户端可持续发送:
|
||||||
|
- 二进制音频
|
||||||
|
- `input.text`(可选)
|
||||||
|
- `response.cancel`(可选)
|
||||||
|
- `tool_call.results`(可选)
|
||||||
|
6. 客户端发送 `session.stop` 或直接断开连接
|
||||||
|
|
||||||
|
顺序错误会返回 `error`,`code = "protocol.order"`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 客户端 -> 服务端消息(输入)
|
||||||
|
|
||||||
|
## 3.1 `hello`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "hello",
|
||||||
|
"version": "v1",
|
||||||
|
"auth": {
|
||||||
|
"apiKey": "optional-api-key",
|
||||||
|
"jwt": "optional-jwt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 |
|
||||||
|
| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 |
|
||||||
|
| `auth` | object \| null | 否 | 仅允许 `apiKey`、`jwt` | 认证载荷 | 认证策略由服务端配置决定 |
|
||||||
|
| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 |
|
||||||
|
| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 |
|
||||||
|
|
||||||
|
认证行为:
|
||||||
|
- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。
|
||||||
|
- 若 `WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY`:`auth.apiKey` 或 `auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。
|
||||||
|
|
||||||
|
## 3.2 `session.start`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "session.start",
|
||||||
|
"audio": {
|
||||||
|
"encoding": "pcm_s16le",
|
||||||
|
"sample_rate_hz": 16000,
|
||||||
|
"channels": 1
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"appId": "assistant_123",
|
||||||
|
"channel": "web",
|
||||||
|
"configVersionId": "cfg_20260217_01",
|
||||||
|
"client": "web-debug",
|
||||||
|
"output": {
|
||||||
|
"mode": "audio"
|
||||||
|
},
|
||||||
|
"systemPrompt": "你是简洁助手",
|
||||||
|
"greeting": "你好,我能帮你什么?"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | 固定 `"session.start"` | 启动会话 | 握手后第二阶段消息 |
|
||||||
|
| `audio` | object \| null | 否 | 仅支持固定值 | 音频格式描述 | 仅用于声明;MVP 实际只接受固定 PCM |
|
||||||
|
| `audio.encoding` | string | 否 | 固定 `"pcm_s16le"` | 编码格式 | 非该值会在模型校验层报错 |
|
||||||
|
| `audio.sample_rate_hz` | number | 否 | 固定 `16000` | 采样率 | 16kHz |
|
||||||
|
| `audio.channels` | number | 否 | 固定 `1` | 声道数 | 单声道 |
|
||||||
|
| `metadata` | object \| null | 否 | 任意对象(会被白名单过滤) | 运行时配置 | 用于 app/channel/提示词/输出模式等覆盖 |
|
||||||
|
|
||||||
|
`metadata` 白名单策略(关键):
|
||||||
|
- 允许透传的标识字段(ID 类):
|
||||||
|
- `appId` / `app_id`
|
||||||
|
- `channel`
|
||||||
|
- `configVersionId` / `config_version_id`
|
||||||
|
- 允许透传的覆盖字段:
|
||||||
|
- `firstTurnMode`
|
||||||
|
- `greeting`
|
||||||
|
- `generatedOpenerEnabled`
|
||||||
|
- `systemPrompt`
|
||||||
|
- `output`
|
||||||
|
- `bargeIn`
|
||||||
|
- `knowledge`
|
||||||
|
- `knowledgeBaseId`
|
||||||
|
- `history`
|
||||||
|
- `userId`
|
||||||
|
- `assistantId`
|
||||||
|
- `source`
|
||||||
|
- 客户端传入 `metadata.services` 会被忽略(服务端会记录 warning),服务配置由后端/环境变量决定。
|
||||||
|
|
||||||
|
`output.mode` 用法:
|
||||||
|
- `"audio"`(默认语音输出)
|
||||||
|
- `"text"`(纯文本输出)
|
||||||
|
- 纯文本模式下仍会收到 `assistant.response.delta/final`;
|
||||||
|
- 不会收到 TTS 音频帧与 `output.audio.start/end`。
|
||||||
|
|
||||||
|
## 3.3 `input.text`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "input.text",
|
||||||
|
"text": "你能做什么?"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | 固定 `"input.text"` | 文本输入 | 跳过 ASR,直接触发 LLM 回答 |
|
||||||
|
| `text` | string | 是 | 非空字符串为佳 | 用户文本 | 用于文本聊天或调试 |
|
||||||
|
|
||||||
|
## 3.4 `response.cancel`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "response.cancel",
|
||||||
|
"graceful": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 默认值 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 |
|
||||||
|
| `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 |
|
||||||
|
|
||||||
|
## 3.5 `tool_call.results`
|
||||||
|
|
||||||
|
仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "tool_call.results",
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"tool_call_id": "call_abc123",
|
||||||
|
"name": "weather",
|
||||||
|
"output": { "temp_c": 21, "condition": "sunny" },
|
||||||
|
"status": { "code": 200, "message": "ok" }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | 固定 `"tool_call.results"` | 工具执行回传 | 客户端工具结果上送 |
|
||||||
|
| `results` | array | 否 | 默认为空数组 | 多个工具结果 | 可批量回传 |
|
||||||
|
| `results[].tool_call_id` | string | 是 | 任意字符串 | 工具调用ID | 必须与 `assistant.tool_call.tool_call_id` 对应 |
|
||||||
|
| `results[].name` | string | 是 | 任意字符串 | 工具名 | 建议与请求一致 |
|
||||||
|
| `results[].output` | any | 否 | 任意 JSON | 工具输出 | 供模型后续组织回答 |
|
||||||
|
| `results[].status` | object | 是 | 包含 `code`、`message` | 执行状态 | 用于判定成功/失败 |
|
||||||
|
| `results[].status.code` | number | 是 | HTTP 风格状态码 | 状态码 | `200-299` 判定成功 |
|
||||||
|
| `results[].status.message` | string | 是 | 任意字符串 | 状态描述 | 例如 `"ok"` / `"timeout"` |
|
||||||
|
|
||||||
|
处理规则:
|
||||||
|
- 未请求过的 `tool_call_id` 会被忽略(防止伪造/串话);
|
||||||
|
- 重复回传会被忽略;
|
||||||
|
- 超时未回传会由服务端合成超时结果(`504`)。
|
||||||
|
|
||||||
|
## 3.6 `session.stop`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "session.stop",
|
||||||
|
"reason": "client_disconnect"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| `type` | string | 是 | 固定 `"session.stop"` | 结束会话 | 正常结束推荐发送 |
|
||||||
|
| `reason` | string \| null | 否 | 任意字符串 | 结束原因 | 服务端会回传到 `session.stopped.reason` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 二进制音频输入(客户端 -> 服务端)
|
||||||
|
|
||||||
|
在 `session.started` 之后可持续发送二进制音频。
|
||||||
|
|
||||||
|
固定格式(MVP):
|
||||||
|
- 编码:`pcm_s16le`
|
||||||
|
- 采样率:`16000`
|
||||||
|
- 声道:`1`
|
||||||
|
- 帧长:20ms = `640 bytes`
|
||||||
|
|
||||||
|
分包规则:
|
||||||
|
- 单个 WebSocket 二进制消息可包含 1 帧或多帧;
|
||||||
|
- 长度必须是 `640` 的整数倍;
|
||||||
|
- 不是 `640` 倍数会触发 `audio.frame_size_mismatch`,该消息整包丢弃;
|
||||||
|
- 奇数字节长度会触发 `audio.invalid_pcm`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 服务端 -> 客户端事件(输出)
|
||||||
|
|
||||||
|
所有 JSON 事件都包含统一包络字段。
|
||||||
|
|
||||||
|
## 5.1 统一包络(Envelope)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "event.name",
|
||||||
|
"timestamp": 1730000000000,
|
||||||
|
"sessionId": "sess_xxx",
|
||||||
|
"seq": 42,
|
||||||
|
"source": "asr",
|
||||||
|
"trackId": "audio_in",
|
||||||
|
"data": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 含义 | 使用说明 |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `type` | string | 事件类型 | 见下方事件清单 |
|
||||||
|
| `timestamp` | number | 事件时间戳(毫秒) | 由 `ev()` 生成 |
|
||||||
|
| `sessionId` | string | 会话ID | 同一连接固定 |
|
||||||
|
| `seq` | number | 单会话递增序号 | 可用于重放、去重、排序 |
|
||||||
|
| `source` | string | 事件来源 | 常见:`asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` |
|
||||||
|
| `trackId` | string | 事件轨道 | 常用:`audio_in`/`audio_out`/`control` |
|
||||||
|
| `data` | object | 结构化数据 | 顶层业务字段会镜像进 `data` 以兼容旧客户端 |
|
||||||
|
|
||||||
|
关联ID(在 `data` 内自动注入,存在时):
|
||||||
|
- `turn_id`:一次用户-助手对话轮次
|
||||||
|
- `utterance_id`:一次用户语音话语
|
||||||
|
- `response_id`:一次助手生成响应
|
||||||
|
- `tool_call_id`:一次工具调用
|
||||||
|
- `tts_id`:一次 TTS 播放段
|
||||||
|
|
||||||
|
## 5.2 事件类型与参数
|
||||||
|
|
||||||
|
### 5.2.1 会话与控制类
|
||||||
|
|
||||||
|
1. `hello.ack`
|
||||||
|
- 关键字段:`version`
|
||||||
|
- 含义:握手成功,应紧接着发送 `session.start`
|
||||||
|
|
||||||
|
2. `session.started`
|
||||||
|
- 关键字段:
|
||||||
|
- `trackId`
|
||||||
|
- `tracks.audio_in`
|
||||||
|
- `tracks.audio_out`
|
||||||
|
- `tracks.control`
|
||||||
|
- `audio`(回显客户端声明的音频元信息)
|
||||||
|
- 含义:会话进入 ACTIVE,可发音频/文本
|
||||||
|
|
||||||
|
3. `config.resolved`
|
||||||
|
- 关键字段:
|
||||||
|
- `config.appId`
|
||||||
|
- `config.channel`
|
||||||
|
- `config.configVersionId`
|
||||||
|
- `config.prompt.sha256`
|
||||||
|
- `config.output`
|
||||||
|
- `config.services`(去密钥后的有效服务配置)
|
||||||
|
- `config.tools.allowlist`
|
||||||
|
- `config.tracks`
|
||||||
|
- 含义:服务端最终生效配置快照,便于前端展示与排错
|
||||||
|
|
||||||
|
4. `heartbeat`
|
||||||
|
- 关键字段:无业务字段(仅 envelope)
|
||||||
|
- 含义:保活心跳
|
||||||
|
- 默认间隔:`heartbeat_interval_sec`(默认 50s)
|
||||||
|
|
||||||
|
5. `session.stopped`
|
||||||
|
- 关键字段:`reason`
|
||||||
|
- 含义:会话结束确认
|
||||||
|
|
||||||
|
6. `error`
|
||||||
|
- 关键字段:
|
||||||
|
- `sender`
|
||||||
|
- `code`
|
||||||
|
- `message`
|
||||||
|
- `stage`
|
||||||
|
- `retryable`
|
||||||
|
- `trackId`
|
||||||
|
- `data.error`(结构化错误镜像)
|
||||||
|
- 含义:统一错误事件
|
||||||
|
|
||||||
|
### 5.2.2 识别与输入侧(ASR/VAD)
|
||||||
|
|
||||||
|
1. `input.speech_started`
|
||||||
|
- 字段:`probability`
|
||||||
|
- 含义:检测到语音开始
|
||||||
|
|
||||||
|
2. `input.speech_stopped`
|
||||||
|
- 字段:`probability`
|
||||||
|
- 含义:检测到语音结束
|
||||||
|
|
||||||
|
3. `transcript.delta`
|
||||||
|
- 字段:`text`
|
||||||
|
- 含义:ASR 增量识别文本(节流发送)
|
||||||
|
|
||||||
|
4. `transcript.final`
|
||||||
|
- 字段:`text`
|
||||||
|
- 含义:ASR 最终识别文本
|
||||||
|
|
||||||
|
### 5.2.3 输出侧(LLM/TTS/Tool)
|
||||||
|
|
||||||
|
1. `assistant.response.delta`
|
||||||
|
- 字段:`text`
|
||||||
|
- 含义:助手增量文本输出(节流发送)
|
||||||
|
|
||||||
|
2. `assistant.response.final`
|
||||||
|
- 字段:`text`
|
||||||
|
- 含义:助手完整文本输出
|
||||||
|
|
||||||
|
3. `assistant.tool_call`
|
||||||
|
- 字段:
|
||||||
|
- `tool_call_id`
|
||||||
|
- `tool_name`
|
||||||
|
- `arguments`(对象)
|
||||||
|
- `executor`(`client` 或 `server`)
|
||||||
|
- `timeout_ms`
|
||||||
|
- `tool_call`(完整工具调用对象)
|
||||||
|
- 含义:通知客户端发生工具调用(用于可视化或客户端执行)
|
||||||
|
|
||||||
|
4. `assistant.tool_result`
|
||||||
|
- 字段:
|
||||||
|
- `source`(`client` 或 `server`)
|
||||||
|
- `tool_call_id`
|
||||||
|
- `tool_name`
|
||||||
|
- `ok`(boolean)
|
||||||
|
- `error`(失败时 `{code,message,retryable}`)
|
||||||
|
- `result`(原始结果对象)
|
||||||
|
- 含义:工具调用结果回执
|
||||||
|
|
||||||
|
5. `output.audio.start`
|
||||||
|
- 含义:TTS 音频输出开始边界
|
||||||
|
|
||||||
|
6. `output.audio.end`
|
||||||
|
- 含义:TTS 音频输出结束边界
|
||||||
|
|
||||||
|
7. `response.interrupted`
|
||||||
|
- 含义:当前回答被打断(barge-in 或 cancel)
|
||||||
|
|
||||||
|
8. `metrics.ttfb`
|
||||||
|
- 字段:`latencyMs`
|
||||||
|
- 含义:首包音频时延(TTFB)
|
||||||
|
|
||||||
|
### 5.2.4 工作流扩展事件(可选)
|
||||||
|
|
||||||
|
若 `metadata.workflow` 生效,会额外出现:
|
||||||
|
- `workflow.started`
|
||||||
|
- `workflow.node.entered`
|
||||||
|
- `workflow.edge.taken`
|
||||||
|
- `workflow.tool.requested`
|
||||||
|
- `workflow.human_transfer`
|
||||||
|
- `workflow.ended`
|
||||||
|
|
||||||
|
这些事件用于外部可视化工作流状态,不影响基础语音会话协议。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 服务端二进制音频输出(服务端 -> 客户端)
|
||||||
|
|
||||||
|
- 音频为 PCM 二进制帧;
|
||||||
|
- 发送单位对齐到 `640 bytes`(不足会补零后发送);
|
||||||
|
- 前端通常结合 `output.audio.start/end` 做播放边界控制;
|
||||||
|
- 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 错误模型与常见错误码
|
||||||
|
|
||||||
|
统一结构(`error` 事件):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"sender": "client",
|
||||||
|
"code": "protocol.invalid_message",
|
||||||
|
"message": "Invalid message: ...",
|
||||||
|
"stage": "protocol",
|
||||||
|
"retryable": false,
|
||||||
|
"trackId": "control",
|
||||||
|
"data": {
|
||||||
|
"error": {
|
||||||
|
"stage": "protocol",
|
||||||
|
"code": "protocol.invalid_message",
|
||||||
|
"message": "Invalid message: ...",
|
||||||
|
"retryable": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
字段语义:
|
||||||
|
- `sender`:错误来源角色(如 `client` / `server` / `auth`)
|
||||||
|
- `code`:机器可读错误码
|
||||||
|
- `message`:人类可读描述
|
||||||
|
- `stage`:阶段(`protocol|audio|asr|llm|tts|tool`)
|
||||||
|
- `retryable`:是否建议重试
|
||||||
|
- `trackId`:错误归属轨道
|
||||||
|
|
||||||
|
常见错误码:
|
||||||
|
- `protocol.invalid_json`
|
||||||
|
- `protocol.invalid_message`
|
||||||
|
- `protocol.order`
|
||||||
|
- `protocol.version_unsupported`
|
||||||
|
- `protocol.unsupported`
|
||||||
|
- `auth.invalid_api_key`
|
||||||
|
- `auth.required`
|
||||||
|
- `audio.invalid_pcm`
|
||||||
|
- `audio.frame_size_mismatch`
|
||||||
|
- `audio.processing_failed`
|
||||||
|
- `server.internal`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 心跳与超时
|
||||||
|
|
||||||
|
服务端后台任务逻辑:
|
||||||
|
- 每隔约 5 秒检查一次连接;
|
||||||
|
- 超过 `inactivity_timeout_sec`(默认 60 秒)未收到任何客户端消息则关闭会话;
|
||||||
|
- 每隔 `heartbeat_interval_sec`(默认 50 秒)发送一次 `heartbeat`。
|
||||||
|
|
||||||
|
客户端建议:
|
||||||
|
- 持续上行音频或定期发送轻量文本消息,避免被判定闲置;
|
||||||
|
- 用 `heartbeat` + `seq` 检测连接活性和事件乱序。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. 实战接入建议
|
||||||
|
|
||||||
|
1. 建连后立即发送 `hello`,收到 `hello.ack` 后再发 `session.start`。
|
||||||
|
2. 语音输入严格按 16k/16bit/mono,并保证每个 WS 二进制消息长度是 `640*n`。
|
||||||
|
3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。
|
||||||
|
4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。
|
||||||
|
5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。
|
||||||
|
6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. 最小完整时序示例
|
||||||
|
|
||||||
|
```text
|
||||||
|
Client -> hello
|
||||||
|
Server <- hello.ack
|
||||||
|
Client -> session.start
|
||||||
|
Server <- session.started
|
||||||
|
Server <- config.resolved
|
||||||
|
Client -> (binary pcm frames...)
|
||||||
|
Server <- input.speech_started / transcript.delta / transcript.final
|
||||||
|
Server <- assistant.response.delta / assistant.response.final
|
||||||
|
Server <- output.audio.start
|
||||||
|
Server <- (binary pcm frames...)
|
||||||
|
Server <- output.audio.end
|
||||||
|
Client -> session.stop
|
||||||
|
Server <- session.stopped
|
||||||
|
```
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ pydantic>=2.5.3
|
|||||||
pydantic-settings>=2.1.0
|
pydantic-settings>=2.1.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
toml>=0.10.2
|
toml>=0.10.2
|
||||||
|
pyyaml>=6.0.1
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
loguru>=0.7.2
|
loguru>=0.7.2
|
||||||
|
|||||||
@@ -43,14 +43,14 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
|
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
|
||||||
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
|
api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars)
|
||||||
base_url: Custom API base URL (for Azure or compatible APIs)
|
base_url: Custom API base URL (for Azure or compatible APIs)
|
||||||
system_prompt: Default system prompt for conversations
|
system_prompt: Default system prompt for conversations
|
||||||
"""
|
"""
|
||||||
super().__init__(model=model)
|
super().__init__(model=model)
|
||||||
|
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||||
self.base_url = base_url or os.getenv("OPENAI_API_URL")
|
self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
|
||||||
self.system_prompt = system_prompt or (
|
self.system_prompt = system_prompt or (
|
||||||
"You are a helpful, friendly voice assistant. "
|
"You are a helpful, friendly voice assistant. "
|
||||||
"Keep your responses concise and conversational. "
|
"Keep your responses concise and conversational. "
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import wave
|
import wave
|
||||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: Optional[str] = None,
|
||||||
|
api_url: Optional[str] = None,
|
||||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
language: str = "auto",
|
language: str = "auto",
|
||||||
@@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: Provider API key
|
api_key: Provider API key
|
||||||
|
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||||
model: ASR model name or alias
|
model: ASR model name or alias
|
||||||
sample_rate: Audio sample rate (16000 recommended)
|
sample_rate: Audio sample rate (16000 recommended)
|
||||||
language: Language code (auto for automatic detection)
|
language: Language code (auto for automatic detection)
|
||||||
@@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
|||||||
if not AIOHTTP_AVAILABLE:
|
if not AIOHTTP_AVAILABLE:
|
||||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||||
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||||
|
self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
|
||||||
self.model = self.MODELS.get(model.lower(), model)
|
self.model = self.MODELS.get(model.lower(), model)
|
||||||
self.interim_interval_ms = interim_interval_ms
|
self.interim_interval_ms = interim_interval_ms
|
||||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||||
@@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
|||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Connect to the service."""
|
"""Connect to the service."""
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||||
self._session = aiohttp.ClientSession(
|
self._session = aiohttp.ClientSession(
|
||||||
headers={
|
headers={
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
@@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
|||||||
)
|
)
|
||||||
form_data.add_field('model', self.model)
|
form_data.add_field('model', self.model)
|
||||||
|
|
||||||
async with self._session.post(self.API_URL, data=form_data) as response:
|
async with self._session.post(self.api_url, data=form_data) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
text = result.get("text", "").strip()
|
text = result.get("text", "").strip()
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_url: Optional[str] = None,
|
||||||
voice: str = "anna",
|
voice: str = "anna",
|
||||||
model: str = "FunAudioLLM/CosyVoice2-0.5B",
|
model: str = "FunAudioLLM/CosyVoice2-0.5B",
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
@@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
|||||||
Initialize OpenAI-compatible TTS service.
|
Initialize OpenAI-compatible TTS service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
|
api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars)
|
||||||
|
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||||
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
|
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
|
||||||
model: Model name
|
model: Model name
|
||||||
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
|
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
|
||||||
@@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
|||||||
|
|
||||||
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
|
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
|
||||||
|
|
||||||
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
|
self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
|
self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
|
||||||
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = None
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
@@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
|||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Initialize HTTP session."""
|
"""Initialize HTTP session."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
|
raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||||
|
|
||||||
self._session = aiohttp.ClientSession(
|
self._session = aiohttp.ClientSession(
|
||||||
headers={
|
headers={
|
||||||
|
|||||||
204
tests/test_agent_config.py
Normal file
204
tests/test_agent_config.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
os.environ.setdefault("LLM_API_KEY", "test-openai-key")
|
||||||
|
os.environ.setdefault("TTS_API_KEY", "test-tts-key")
|
||||||
|
os.environ.setdefault("ASR_API_KEY", "test-asr-key")
|
||||||
|
|
||||||
|
from app.config import load_settings
|
||||||
|
|
||||||
|
|
||||||
|
def _write_yaml(path: Path, content: str) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str:
|
||||||
|
return f"""
|
||||||
|
agent:
|
||||||
|
vad:
|
||||||
|
type: silero
|
||||||
|
model_path: data/vad/silero_vad.onnx
|
||||||
|
threshold: 0.63
|
||||||
|
min_speech_duration_ms: 100
|
||||||
|
eou_threshold_ms: 800
|
||||||
|
|
||||||
|
llm:
|
||||||
|
provider: openai_compatible
|
||||||
|
model: {llm_model}
|
||||||
|
temperature: 0.2
|
||||||
|
api_key: {llm_key}
|
||||||
|
api_url: https://example-llm.invalid/v1
|
||||||
|
|
||||||
|
tts:
|
||||||
|
provider: openai_compatible
|
||||||
|
api_key: test-tts-key
|
||||||
|
api_url: https://example-tts.invalid/v1/audio/speech
|
||||||
|
model: FunAudioLLM/CosyVoice2-0.5B
|
||||||
|
voice: anna
|
||||||
|
speed: 1.0
|
||||||
|
|
||||||
|
asr:
|
||||||
|
provider: openai_compatible
|
||||||
|
api_key: test-asr-key
|
||||||
|
api_url: https://example-asr.invalid/v1/audio/transcriptions
|
||||||
|
model: FunAudioLLM/SenseVoiceSmall
|
||||||
|
interim_interval_ms: 500
|
||||||
|
min_audio_ms: 300
|
||||||
|
start_min_speech_ms: 160
|
||||||
|
pre_speech_ms: 240
|
||||||
|
final_tail_ms: 120
|
||||||
|
|
||||||
|
duplex:
|
||||||
|
enabled: true
|
||||||
|
system_prompt: You are a strict test assistant.
|
||||||
|
|
||||||
|
barge_in:
|
||||||
|
min_duration_ms: 200
|
||||||
|
silence_tolerance_ms: 60
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
config_dir = tmp_path / "config" / "agents"
|
||||||
|
_write_yaml(
|
||||||
|
config_dir / "support.yaml",
|
||||||
|
_full_agent_yaml(llm_model="gpt-4.1-mini"),
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = load_settings(
|
||||||
|
argv=["--agent-profile", "support"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert settings.llm_model == "gpt-4.1-mini"
|
||||||
|
assert settings.llm_temperature == 0.2
|
||||||
|
assert settings.vad_threshold == 0.63
|
||||||
|
assert settings.agent_config_source == "cli_profile"
|
||||||
|
assert settings.agent_config_path == str((config_dir / "support.yaml").resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
env_file = tmp_path / "config" / "agents" / "env.yaml"
|
||||||
|
cli_file = tmp_path / "config" / "agents" / "cli.yaml"
|
||||||
|
|
||||||
|
_write_yaml(env_file, _full_agent_yaml(llm_model="env-model"))
|
||||||
|
_write_yaml(cli_file, _full_agent_yaml(llm_model="cli-model"))
|
||||||
|
|
||||||
|
monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_file))
|
||||||
|
|
||||||
|
settings = load_settings(argv=["--agent-config", str(cli_file)])
|
||||||
|
|
||||||
|
assert settings.llm_model == "cli-model"
|
||||||
|
assert settings.agent_config_source == "cli_path"
|
||||||
|
assert settings.agent_config_path == str(cli_file.resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path):
    """With no CLI args and no env overrides, config/agents/default.yaml is used."""
    monkeypatch.chdir(tmp_path)
    # Hermetic: strip any ambient selection coming from the developer's shell.
    for var in ("AGENT_CONFIG_PATH", "AGENT_PROFILE"):
        monkeypatch.delenv(var, raising=False)

    default_file = tmp_path / "config" / "agents" / "default.yaml"
    _write_yaml(default_file, _full_agent_yaml(llm_model="from-default"))

    settings = load_settings(argv=[])

    assert settings.llm_model == "from-default"
    assert settings.agent_config_source == "default"
    assert settings.agent_config_path == str(default_file.resolve())
|
def test_missing_required_agent_settings_fail(monkeypatch, tmp_path):
    """A YAML that provides only agent.llm.model is rejected as incomplete."""
    monkeypatch.chdir(tmp_path)
    # Deliberately minimal config: everything but the model is absent.
    partial_yaml = "agent:\n  llm:\n    model: gpt-4o-mini"
    config_file = tmp_path / "missing-required.yaml"
    _write_yaml(config_file, partial_yaml)

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_blank_required_provider_key_fails(monkeypatch, tmp_path):
    """An empty-string provider api_key counts as missing, not as a valid value."""
    monkeypatch.chdir(tmp_path)
    config_file = tmp_path / "blank-key.yaml"
    _write_yaml(config_file, _full_agent_yaml(llm_key=""))

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_missing_tts_api_url_fails(monkeypatch, tmp_path):
    """Removing the TTS api_url line from an otherwise-complete YAML fails validation."""
    monkeypatch.chdir(tmp_path)
    # Cut the exact api_url line (with its trailing newline) out of the full config.
    removed_line = "    api_url: https://example-tts.invalid/v1/audio/speech\n"
    yaml_without_tts_url = _full_agent_yaml().replace(removed_line, "")

    config_file = tmp_path / "missing-tts-url.yaml"
    _write_yaml(config_file, yaml_without_tts_url)

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_missing_asr_api_url_fails(monkeypatch, tmp_path):
    """Removing the ASR api_url line from an otherwise-complete YAML fails validation."""
    monkeypatch.chdir(tmp_path)
    # Cut the exact api_url line (with its trailing newline) out of the full config.
    removed_line = "    api_url: https://example-asr.invalid/v1/audio/transcriptions\n"
    yaml_without_asr_url = _full_agent_yaml().replace(removed_line, "")

    config_file = tmp_path / "missing-asr-url.yaml"
    _write_yaml(config_file, yaml_without_asr_url)

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path):
    """Any key outside the known agent schema must be rejected loudly."""
    monkeypatch.chdir(tmp_path)
    config_file = tmp_path / "bad-agent.yaml"
    # Append an extra key at agent-section nesting so only the key itself is invalid.
    _write_yaml(config_file, _full_agent_yaml() + "\n  unknown_option: true")

    with pytest.raises(ValueError, match="Unknown agent config keys"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path):
    """The removed 'siliconflow' section must raise a targeted migration error."""
    monkeypatch.chdir(tmp_path)
    legacy_yaml = "agent:\n  siliconflow:\n    api_key: x"
    config_file = tmp_path / "legacy-siliconflow.yaml"
    _write_yaml(config_file, legacy_yaml)

    with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"):
        load_settings(argv=["--agent-config", str(config_file)])
|
def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path):
    """A ${VAR} reference to an unset environment variable is a hard error.

    Fix: explicitly unset UNSET_LLM_API_KEY so the test is hermetic — if that
    variable happened to exist in the ambient environment, substitution would
    succeed and the expected ValueError would never be raised.
    """
    monkeypatch.chdir(tmp_path)
    monkeypatch.delenv("UNSET_LLM_API_KEY", raising=False)

    file_path = tmp_path / "bad-ref.yaml"
    _write_yaml(
        file_path,
        _full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"),
    )

    with pytest.raises(ValueError, match="Missing environment variable"):
        load_settings(argv=["--agent-config", str(file_path)])
||||||
Reference in New Issue
Block a user