From 8b9064f6e6e2963bdef46805a578494d7607599e Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 25 Feb 2026 15:52:55 +0800 Subject: [PATCH] Organize config --- .env.example | 64 +--- README.md | 26 +- app/config.py | 382 +++++++++++++++++++++- app/main.py | 6 + config/agents/example.yaml | 50 +++ core/duplex_pipeline.py | 59 +++- docs/ws_v1_schema_zh.md | 520 ++++++++++++++++++++++++++++++ requirements.txt | 1 + services/llm.py | 6 +- services/openai_compatible_asr.py | 12 +- services/openai_compatible_tts.py | 10 +- tests/test_agent_config.py | 204 ++++++++++++ 12 files changed, 1248 insertions(+), 92 deletions(-) create mode 100644 config/agents/example.yaml create mode 100644 docs/ws_v1_schema_zh.md create mode 100644 tests/test_agent_config.py diff --git a/.env.example b/.env.example index f62a4c6..c812b48 100644 --- a/.env.example +++ b/.env.example @@ -23,57 +23,21 @@ CHUNK_SIZE_MS=20 DEFAULT_CODEC=pcm MAX_AUDIO_BUFFER_SECONDS=30 -# VAD / EOU -VAD_TYPE=silero -VAD_MODEL_PATH=data/vad/silero_vad.onnx -# Higher = stricter speech detection (fewer false positives, more misses). -VAD_THRESHOLD=0.5 -# Require this much continuous speech before utterance can be valid. -VAD_MIN_SPEECH_DURATION_MS=100 -# Silence duration required to finalize one user turn. -VAD_EOU_THRESHOLD_MS=800 +# Agent profile selection (optional fallback when CLI args are not used) +# Prefer CLI: +# python -m app.main --agent-config config/agents/default.yaml +# python -m app.main --agent-profile default +# AGENT_CONFIG_PATH=config/agents/default.yaml +# AGENT_PROFILE=default +AGENT_CONFIG_DIR=config/agents -# LLM -OPENAI_API_KEY=your_openai_api_key_here -# Optional for OpenAI-compatible providers. -# OPENAI_API_URL=https://api.openai.com/v1 -LLM_MODEL=gpt-4o-mini -LLM_TEMPERATURE=0.7 - -# TTS -# edge: no API key needed -# openai_compatible: compatible with SiliconFlow-style endpoints -TTS_PROVIDER=openai_compatible -TTS_VOICE=anna -TTS_SPEED=1.0 - -# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible) -SILICONFLOW_API_KEY=your_siliconflow_api_key_here -SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B -SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall - -# ASR -ASR_PROVIDER=openai_compatible -# Interim cadence and minimum audio before interim decode. -ASR_INTERIM_INTERVAL_MS=500 -ASR_MIN_AUDIO_MS=300 -# ASR start gate: ignore micro-noise, then commit to one turn once started. -ASR_START_MIN_SPEECH_MS=160 -# Pre-roll protects beginning phonemes. -ASR_PRE_SPEECH_MS=240 -# Tail silence protects ending phonemes. -ASR_FINAL_TAIL_MS=120 - -# Duplex behavior -DUPLEX_ENABLED=true -# DUPLEX_GREETING=Hello! How can I help you today? -DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational. - -# Barge-in (user interrupting assistant) -# Min user speech duration needed to interrupt assistant audio. -BARGE_IN_MIN_DURATION_MS=200 -# Allowed silence during potential barge-in (ms) before reset. -BARGE_IN_SILENCE_TOLERANCE_MS=60 +# Optional: provider credentials referenced from YAML, e.g. ${LLM_API_KEY} +# LLM_API_KEY=your_llm_api_key_here +# LLM_API_URL=https://api.openai.com/v1 +# TTS_API_KEY=your_tts_api_key_here +# TTS_API_URL=https://api.example.com/v1/audio/speech +# ASR_API_KEY=your_asr_api_key_here +# ASR_API_URL=https://api.example.com/v1/audio/transcriptions # Logging LOG_LEVEL=INFO diff --git a/README.md b/README.md index 17d9e3a..de8f20f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,30 @@ It is currently in an early, experimental stage. uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ``` +使用 agent profile(推荐) + +``` +python -m app.main --agent-profile default +``` + +使用指定 YAML + +``` +python -m app.main --agent-config config/agents/default.yaml +``` + +Agent 配置路径优先级 +1. `--agent-config` +2. `--agent-profile`(映射到 `config/agents/.yaml`) +3. `AGENT_CONFIG_PATH` +4. `AGENT_PROFILE` +5. `config/agents/default.yaml`(若存在) + +说明 +- Agent 相关配置是严格模式:YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。 +- 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。 +- `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。 + 测试 ``` @@ -28,4 +52,4 @@ python mic_client.py `/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames. -See `/Users/wx44wx/.codex/worktrees/d817/AI-VideoAssistant/engine/docs/ws_v1_schema.md`. +See `docs/ws_v1_schema.md`. diff --git a/app/config.py b/app/config.py index 1e3e1b3..7eaf74d 100644 --- a/app/config.py +++ b/app/config.py @@ -1,9 +1,354 @@ -"""Configuration management using Pydantic settings.""" +"""Configuration management using Pydantic settings and agent YAML profiles.""" + +import json +import os +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple -from typing import List, Optional from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict -import json + +try: + import yaml +except ImportError: # pragma: no cover - validated when agent YAML is used + yaml = None + + +_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}") +_DEFAULT_AGENT_CONFIG_DIR = "config/agents" +_DEFAULT_AGENT_CONFIG_FILE = "default.yaml" +_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = { + "vad": { + "type": "vad_type", + "model_path": "vad_model_path", + "threshold": "vad_threshold", + "min_speech_duration_ms": "vad_min_speech_duration_ms", + "eou_threshold_ms": "vad_eou_threshold_ms", + }, + "llm": { + "provider": "llm_provider", + "model": "llm_model", + "temperature": "llm_temperature", + "api_key": "llm_api_key", + "api_url": "llm_api_url", + }, + "tts": { + "provider": "tts_provider", + "api_key": "tts_api_key", + "api_url": "tts_api_url", + "model": "tts_model", + "voice": "tts_voice", + "speed": "tts_speed", + }, + "asr": { + "provider": "asr_provider", + "api_key": "asr_api_key", + "api_url": "asr_api_url", + "model": "asr_model", + "interim_interval_ms": "asr_interim_interval_ms", + "min_audio_ms": "asr_min_audio_ms", + "start_min_speech_ms": "asr_start_min_speech_ms", + "pre_speech_ms": "asr_pre_speech_ms", + "final_tail_ms": "asr_final_tail_ms", + }, + "duplex": { + "enabled": "duplex_enabled", + "greeting": "duplex_greeting", + "system_prompt": "duplex_system_prompt", + }, + "barge_in": { + "min_duration_ms": "barge_in_min_duration_ms", + "silence_tolerance_ms": "barge_in_silence_tolerance_ms", + }, +} +_AGENT_SETTING_KEYS = { + "vad_type", + "vad_model_path", + "vad_threshold", + "vad_min_speech_duration_ms", + "vad_eou_threshold_ms", + "llm_provider", + "llm_api_key", + "llm_api_url", + "llm_model", + "llm_temperature", + "tts_provider", + "tts_api_key", + "tts_api_url", + "tts_model", + "tts_voice", + "tts_speed", + "asr_provider", + "asr_api_key", + "asr_api_url", + "asr_model", + "asr_interim_interval_ms", + "asr_min_audio_ms", + "asr_start_min_speech_ms", + "asr_pre_speech_ms", + "asr_final_tail_ms", + "duplex_enabled", + "duplex_greeting", + "duplex_system_prompt", + "barge_in_min_duration_ms", + "barge_in_silence_tolerance_ms", +} +_BASE_REQUIRED_AGENT_SETTING_KEYS = { + "vad_type", + "vad_model_path", + "vad_threshold", + "vad_min_speech_duration_ms", + "vad_eou_threshold_ms", + "llm_provider", + "llm_model", + "llm_temperature", + "tts_provider", + "tts_voice", + "tts_speed", + "asr_provider", + "asr_interim_interval_ms", + "asr_min_audio_ms", + "asr_start_min_speech_ms", + "asr_pre_speech_ms", + "asr_final_tail_ms", + "duplex_enabled", + "duplex_system_prompt", + "barge_in_min_duration_ms", + "barge_in_silence_tolerance_ms", +} +_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} + + +def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str: + return str(overrides.get(key) or default).strip().lower() + + +def _is_blank(value: Any) -> bool: + return value is None or (isinstance(value, str) and not value.strip()) + + +@dataclass(frozen=True) +class AgentConfigSelection: + """Resolved agent config location and how it was selected.""" + + path: Optional[Path] + source: str + + +def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]: + """Parse only agent-related CLI flags from argv.""" + config_path: Optional[str] = None + profile: Optional[str] = None + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith("--agent-config="): + config_path = arg.split("=", 1)[1].strip() or None + elif arg == "--agent-config" and i + 1 < len(argv): + config_path = argv[i + 1].strip() or None + i += 1 + elif arg.startswith("--agent-profile="): + profile = arg.split("=", 1)[1].strip() or None + elif arg == "--agent-profile" and i + 1 < len(argv): + profile = argv[i + 1].strip() or None + i += 1 + i += 1 + return config_path, profile + + +def _agent_config_dir() -> Path: + base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR)) + if not base_dir.is_absolute(): + base_dir = Path.cwd() / base_dir + return base_dir.resolve() + + +def _resolve_agent_selection( + agent_config_path: Optional[str] = None, + agent_profile: Optional[str] = None, + argv: Optional[List[str]] = None, +) -> AgentConfigSelection: + cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:])) + path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH") + profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE") + source = "none" + candidate: Optional[Path] = None + + if path_value: + source = "cli_path" if (agent_config_path or cli_path) else "env_path" + candidate = Path(path_value) + elif profile_value: + source = "cli_profile" if (agent_profile or cli_profile) else "env_profile" + candidate = _agent_config_dir() / f"{profile_value}.yaml" + else: + fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE + if fallback.exists(): + source = "default" + candidate = fallback + + if candidate is None: + raise ValueError( + "Agent YAML config is required. Provide --agent-config/--agent-profile " + "or create config/agents/default.yaml." + ) + + if not candidate.is_absolute(): + candidate = (Path.cwd() / candidate).resolve() + else: + candidate = candidate.resolve() + + if not candidate.exists(): + raise ValueError(f"Agent config file not found ({source}): {candidate}") + if not candidate.is_file(): + raise ValueError(f"Agent config path is not a file: {candidate}") + return AgentConfigSelection(path=candidate, source=source) + + +def _resolve_env_refs(value: Any) -> Any: + """Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively.""" + if isinstance(value, dict): + return {k: _resolve_env_refs(v) for k, v in value.items()} + if isinstance(value, list): + return [_resolve_env_refs(item) for item in value] + if not isinstance(value, str) or "${" not in value: + return value + + def _replace(match: re.Match[str]) -> str: + env_key = match.group(1) + default_value = match.group(2) + env_value = os.getenv(env_key) + if env_value is None: + if default_value is None: + raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}") + return default_value + return env_value + + return _ENV_REF_PATTERN.sub(_replace, value) + + +def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]: + """Normalize YAML into flat Settings fields.""" + normalized: Dict[str, Any] = {} + + for key, value in raw.items(): + if key == "siliconflow": + raise ValueError( + "Section 'siliconflow' is no longer supported. " + "Move provider-specific fields into agent.llm / agent.asr / agent.tts." + ) + section_map = _AGENT_SECTION_KEY_MAP.get(key) + if section_map is None: + normalized[key] = value + continue + + if not isinstance(value, dict): + raise ValueError(f"Agent config section '{key}' must be a mapping") + + for nested_key, nested_value in value.items(): + mapped_key = section_map.get(nested_key) + if mapped_key is None: + raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'") + normalized[mapped_key] = nested_value + + unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS) + if unknown_keys: + raise ValueError( + "Unknown agent config keys in YAML: " + + ", ".join(unknown_keys) + ) + return normalized + + +def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]: + missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides)) + string_required = { + "vad_type", + "vad_model_path", + "llm_provider", + "llm_model", + "tts_provider", + "tts_voice", + "asr_provider", + "duplex_system_prompt", + } + for key in string_required: + if key in overrides and _is_blank(overrides.get(key)): + missing.add(key) + + llm_provider = _normalized_provider(overrides, "llm_provider", "openai") + if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai": + if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")): + missing.add("llm_api_key") + + tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible") + if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")): + missing.add("tts_api_key") + if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")): + missing.add("tts_api_url") + if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")): + missing.add("tts_model") + + asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible") + if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")): + missing.add("asr_api_key") + if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")): + missing.add("asr_api_url") + if "asr_model" not in overrides or _is_blank(overrides.get("asr_model")): + missing.add("asr_model") + + return sorted(missing) + + +def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]: + if yaml is None: + raise RuntimeError( + "PyYAML is required for agent YAML configuration. Install with: pip install pyyaml" + ) + + with selection.path.open("r", encoding="utf-8") as file: + raw = yaml.safe_load(file) or {} + + if not isinstance(raw, dict): + raise ValueError(f"Agent config must be a YAML mapping: {selection.path}") + + if "agent" in raw: + agent_value = raw["agent"] + if not isinstance(agent_value, dict): + raise ValueError("The 'agent' key in YAML must be a mapping") + raw = agent_value + + resolved = _resolve_env_refs(raw) + overrides = _normalize_agent_overrides(resolved) + missing_required = _missing_required_keys(overrides) + if missing_required: + raise ValueError( + f"Missing required agent settings in YAML ({selection.path}): " + + ", ".join(missing_required) + ) + + overrides["agent_config_path"] = str(selection.path) + overrides["agent_config_source"] = selection.source + return overrides + + +def load_settings( + agent_config_path: Optional[str] = None, + agent_profile: Optional[str] = None, + argv: Optional[List[str]] = None, +) -> "Settings": + """Load settings from .env and optional agent YAML.""" + selection = _resolve_agent_selection( + agent_config_path=agent_config_path, + agent_profile=agent_profile, + argv=argv, + ) + agent_overrides = _load_agent_overrides(selection) + return Settings(**agent_overrides) class Settings(BaseSettings): @@ -37,30 +382,35 @@ class Settings(BaseSettings): vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds") vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds") - # OpenAI / LLM Configuration - openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key") - openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)") + # LLM Configuration + llm_provider: str = Field( + default="openai", + description="LLM provider (openai, openai_compatible, siliconflow)" + ) + llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key") + llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL") llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation") # TTS Configuration tts_provider: str = Field( default="openai_compatible", - description="TTS provider (edge, openai_compatible; siliconflow alias supported)" + description="TTS provider (edge, openai_compatible, siliconflow)" ) + tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key") + tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") + tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_voice: str = Field(default="anna", description="TTS voice name") tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") - # SiliconFlow Configuration - siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key") - siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model") - # ASR Configuration asr_provider: str = Field( default="openai_compatible", - description="ASR provider (openai_compatible, buffered; siliconflow alias supported)" + description="ASR provider (openai_compatible, buffered, siliconflow)" ) - siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model") + asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key") + asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") + asr_model: Optional[str] = Field(default=None, description="ASR model name") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") asr_start_min_speech_ms: int = Field( @@ -122,6 +472,10 @@ class Settings(BaseSettings): backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds") history_default_user_id: int = Field(default=1, description="Fallback user_id for history records") + # Agent YAML metadata + agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path") + agent_config_source: str = Field(default="none", description="How the agent YAML was selected") + @property def chunk_size_bytes(self) -> int: """Calculate chunk size in bytes based on sample rate and duration.""" @@ -146,7 +500,7 @@ class Settings(BaseSettings): # Global settings instance -settings = Settings() +settings = load_settings() def get_settings() -> Settings: diff --git a/app/main.py b/app/main.py index e2c6c75..bf05cc9 100644 --- a/app/main.py +++ b/app/main.py @@ -357,6 +357,12 @@ async def startup_event(): logger.info(f"Server: {settings.host}:{settings.port}") logger.info(f"Sample rate: {settings.sample_rate} Hz") logger.info(f"VAD model: {settings.vad_model_path}") + if settings.agent_config_path: + logger.info( + f"Agent config loaded ({settings.agent_config_source}): {settings.agent_config_path}" + ) + else: + logger.info("Agent config: none (using .env/default agent values)") @app.on_event("shutdown") diff --git a/config/agents/example.yaml b/config/agents/example.yaml new file mode 100644 index 0000000..114830e --- /dev/null +++ b/config/agents/example.yaml @@ -0,0 +1,50 @@ +# Agent behavior configuration (safe to edit per profile) +# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers). +# Infra/server/network settings should stay in .env. + +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + # Required: no fallback. You can still reference env explicitly. + api_key: your_llm_api_key + # Optional for OpenAI-compatible endpoints: + api_url: https://api.qnaigc.com/v1 + + tts: + # provider: edge | openai_compatible | siliconflow + provider: openai_compatible + api_key: your_tts_api_key + api_url: https://api.siliconflow.cn/v1/audio/speech + model: FunAudioLLM/CosyVoice2-0.5B + voice: anna + speed: 1.0 + + asr: + # provider: buffered | openai_compatible | siliconflow + provider: openai_compatible + api_key: you_asr_api_key + api_url: https://api.siliconflow.cn/v1/audio/transcriptions + model: FunAudioLLM/SenseVoiceSmall + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational. + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 diff --git a/core/duplex_pipeline.py b/core/duplex_pipeline.py index 2265db4..a148fc6 100644 --- a/core/duplex_pipeline.py +++ b/core/duplex_pipeline.py @@ -310,7 +310,12 @@ class DuplexPipeline: def resolved_runtime_config(self) -> Dict[str, Any]: """Return current effective runtime configuration without secrets.""" - llm_provider = str(self._runtime_llm.get("provider") or "openai").lower() + llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower() + llm_base_url = ( + self._runtime_llm.get("baseUrl") + or settings.llm_api_url + or self._default_llm_base_url(llm_provider) + ) tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower() asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower() output_mode = str(self._runtime_output.get("mode") or "").strip().lower() @@ -323,18 +328,18 @@ class DuplexPipeline: "llm": { "provider": llm_provider, "model": str(self._runtime_llm.get("model") or settings.llm_model), - "baseUrl": self._runtime_llm.get("baseUrl") or settings.openai_api_url, + "baseUrl": llm_base_url, }, "asr": { "provider": asr_provider, - "model": str(self._runtime_asr.get("model") or settings.siliconflow_asr_model), + "model": str(self._runtime_asr.get("model") or settings.asr_model or ""), "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), }, "tts": { "enabled": self._tts_output_enabled(), "provider": tts_provider, - "model": str(self._runtime_tts.get("model") or settings.siliconflow_tts_model), + "model": str(self._runtime_tts.get("model") or settings.tts_model or ""), "voice": str(self._runtime_tts.get("voice") or settings.tts_voice), "speed": float(self._runtime_tts.get("speed") or settings.tts_speed), }, @@ -452,6 +457,18 @@ class DuplexPipeline: normalized = str(provider or "").strip().lower() return normalized in {"openai_compatible", "openai-compatible", "siliconflow"} + @staticmethod + def _is_llm_provider_supported(provider: Any) -> bool: + normalized = str(provider or "").strip().lower() + return normalized in {"openai", "openai_compatible", "openai-compatible", "siliconflow"} + + @staticmethod + def _default_llm_base_url(provider: Any) -> Optional[str]: + normalized = str(provider or "").strip().lower() + if normalized == "siliconflow": + return "https://api.siliconflow.cn/v1" + return None + def _tts_output_enabled(self) -> bool: enabled = self._coerce_bool(self._runtime_tts.get("enabled")) if enabled is not None: @@ -527,12 +544,16 @@ class DuplexPipeline: try: # Connect LLM service if not self.llm_service: - llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key - llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url + llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower() + llm_api_key = self._runtime_llm.get("apiKey") or settings.llm_api_key + llm_base_url = ( + self._runtime_llm.get("baseUrl") + or settings.llm_api_url + or self._default_llm_base_url(llm_provider) + ) llm_model = self._runtime_llm.get("model") or settings.llm_model - llm_provider = (self._runtime_llm.get("provider") or "openai").lower() - if llm_provider == "openai" and llm_api_key: + if self._is_llm_provider_supported(llm_provider) and llm_api_key: self.llm_service = OpenAILLMService( api_key=llm_api_key, base_url=llm_base_url, @@ -540,7 +561,7 @@ class DuplexPipeline: knowledge_config=self._resolved_knowledge_config(), ) else: - logger.warning("No OpenAI API key - using mock LLM") + logger.warning("LLM provider unsupported or API key missing - using mock LLM") self.llm_service = MockLLMService() if hasattr(self.llm_service, "set_knowledge_config"): @@ -556,20 +577,22 @@ class DuplexPipeline: if tts_output_enabled: if not self.tts_service: tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() - tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key + tts_api_key = self._runtime_tts.get("apiKey") or settings.tts_api_key + tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url tts_voice = self._runtime_tts.get("voice") or settings.tts_voice - tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model + tts_model = self._runtime_tts.get("model") or settings.tts_model tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) if self._is_openai_compatible_provider(tts_provider) and tts_api_key: self.tts_service = OpenAICompatibleTTSService( api_key=tts_api_key, + api_url=tts_api_url, voice=tts_voice, - model=tts_model, + model=tts_model or "FunAudioLLM/CosyVoice2-0.5B", sample_rate=settings.sample_rate, speed=tts_speed ) - logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)") + logger.info(f"Using OpenAI-compatible TTS service (provider={tts_provider})") else: self.tts_service = EdgeTTSService( voice=tts_voice, @@ -592,21 +615,23 @@ class DuplexPipeline: # Connect ASR service if not self.asr_service: asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower() - asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key - asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model + asr_api_key = self._runtime_asr.get("apiKey") or settings.asr_api_key + asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url + asr_model = self._runtime_asr.get("model") or settings.asr_model asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) if self._is_openai_compatible_provider(asr_provider) and asr_api_key: self.asr_service = OpenAICompatibleASRService( api_key=asr_api_key, - model=asr_model, + api_url=asr_api_url, + model=asr_model or "FunAudioLLM/SenseVoiceSmall", sample_rate=settings.sample_rate, interim_interval_ms=asr_interim_interval, min_audio_for_interim_ms=asr_min_audio_ms, on_transcript=self._on_transcript_callback ) - logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)") + logger.info(f"Using OpenAI-compatible ASR service (provider={asr_provider})") else: self.asr_service = BufferedASRService( sample_rate=settings.sample_rate diff --git a/docs/ws_v1_schema_zh.md b/docs/ws_v1_schema_zh.md new file mode 100644 index 0000000..25c5ad9 --- /dev/null +++ b/docs/ws_v1_schema_zh.md @@ -0,0 +1,520 @@ +# WS v1 协议完整说明(中文) + +本文档描述 `/ws` 端点的 WebSocket v1 协议,覆盖: +- 客户端输入(JSON 文本消息 + 二进制音频); +- 服务端输出(JSON 事件 + 二进制音频); +- 每个参数的类型、约束、含义与使用方式; +- 握手顺序、状态机、错误语义与实现细节。 + +实现对照来源: +- `models/ws_v1.py` +- `core/session.py` +- `core/duplex_pipeline.py` +- `app/main.py` + +--- + +## 1. 传输与基础规则 + +- 连接地址:`ws:///ws` +- 单连接双通道承载: + - 文本帧:JSON 控制消息(严格校验 schema) + - 二进制帧:原始 PCM 音频 +- JSON 校验策略: + - 所有已定义客户端消息都 `extra="forbid"`,即不允许未声明字段; + - `hello.version` 固定必须是 `"v1"`; + - 缺失 `type` 或未知 `type` 会返回协议错误。 + +--- + +## 2. 状态机与消息顺序 + +### 2.1 服务端状态 + +- `WAIT_HELLO`:等待 `hello` +- `WAIT_START`:已通过握手,等待 `session.start` +- `ACTIVE`:会话运行中,可收发文本/音频 +- `STOPPED`:会话结束 + +### 2.2 正确顺序 + +1. 客户端发送 `hello` +2. 服务端返回 `hello.ack` +3. 客户端发送 `session.start` +4. 服务端返回 `session.started` +5. 客户端可持续发送: + - 二进制音频 + - `input.text`(可选) + - `response.cancel`(可选) + - `tool_call.results`(可选) +6. 客户端发送 `session.stop` 或直接断开连接 + +顺序错误会返回 `error`,`code = "protocol.order"`。 + +--- + +## 3. 客户端 -> 服务端消息(输入) + +## 3.1 `hello` + +示例: + +```json +{ + "type": "hello", + "version": "v1", + "auth": { + "apiKey": "optional-api-key", + "jwt": "optional-jwt" + } +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"hello"` | 消息类型 | 握手第一条消息 | +| `version` | string | 是 | 固定 `"v1"` | 协议版本 | 版本不匹配会 `protocol.version_unsupported` 并断开 | +| `auth` | object \| null | 否 | 仅允许 `apiKey`、`jwt` | 认证载荷 | 认证策略由服务端配置决定 | +| `auth.apiKey` | string \| null | 否 | 任意字符串 | API Key | 若服务端配置 `WS_API_KEY`,必须精确匹配 | +| `auth.jwt` | string \| null | 否 | 任意字符串 | JWT 字符串 | 当 `WS_REQUIRE_AUTH=true` 时可用于满足“有认证信息”条件 | + +认证行为: +- 若设置了 `WS_API_KEY`:必须提供且匹配 `auth.apiKey`,否则 `auth.invalid_api_key` 并关闭连接。 +- 若 `WS_REQUIRE_AUTH=true` 且未设置 `WS_API_KEY`:`auth.apiKey` 或 `auth.jwt` 至少一个非空,否则 `auth.required` 并关闭连接。 + +## 3.2 `session.start` + +示例: + +```json +{ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": 16000, + "channels": 1 + }, + "metadata": { + "appId": "assistant_123", + "channel": "web", + "configVersionId": "cfg_20260217_01", + "client": "web-debug", + "output": { + "mode": "audio" + }, + "systemPrompt": "你是简洁助手", + "greeting": "你好,我能帮你什么?" + } +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"session.start"` | 启动会话 | 握手后第二阶段消息 | +| `audio` | object \| null | 否 | 仅支持固定值 | 音频格式描述 | 仅用于声明;MVP 实际只接受固定 PCM | +| `audio.encoding` | string | 否 | 固定 `"pcm_s16le"` | 编码格式 | 非该值会在模型校验层报错 | +| `audio.sample_rate_hz` | number | 否 | 固定 `16000` | 采样率 | 16kHz | +| `audio.channels` | number | 否 | 固定 `1` | 声道数 | 单声道 | +| `metadata` | object \| null | 否 | 任意对象(会被白名单过滤) | 运行时配置 | 用于 app/channel/提示词/输出模式等覆盖 | + +`metadata` 白名单策略(关键): +- 允许透传的标识字段(ID 类): + - `appId` / `app_id` + - `channel` + - `configVersionId` / `config_version_id` +- 允许透传的覆盖字段: + - `firstTurnMode` + - `greeting` + - `generatedOpenerEnabled` + - `systemPrompt` + - `output` + - `bargeIn` + - `knowledge` + - `knowledgeBaseId` + - `history` + - `userId` + - `assistantId` + - `source` +- 客户端传入 `metadata.services` 会被忽略(服务端会记录 warning),服务配置由后端/环境变量决定。 + +`output.mode` 用法: +- `"audio"`(默认语音输出) +- `"text"`(纯文本输出) + - 纯文本模式下仍会收到 `assistant.response.delta/final`; + - 不会收到 TTS 音频帧与 `output.audio.start/end`。 + +## 3.3 `input.text` + +示例: + +```json +{ + "type": "input.text", + "text": "你能做什么?" +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"input.text"` | 文本输入 | 跳过 ASR,直接触发 LLM 回答 | +| `text` | string | 是 | 非空字符串为佳 | 用户文本 | 用于文本聊天或调试 | + +## 3.4 `response.cancel` + +示例: + +```json +{ + "type": "response.cancel", + "graceful": false +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 默认值 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | - | 固定 `"response.cancel"` | 请求中断当前回答 | +| `graceful` | boolean | 否 | `false` | 取消方式 | `false` 立即打断;`true` 当前实现主要用于记录日志,不强制中断 | + +## 3.5 `tool_call.results` + +仅在工具执行端为客户端时使用(`assistant.tool_call.executor == "client"`)。 + +示例: + +```json +{ + "type": "tool_call.results", + "results": [ + { + "tool_call_id": "call_abc123", + "name": "weather", + "output": { "temp_c": 21, "condition": "sunny" }, + "status": { "code": 200, "message": "ok" } + } + ] +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"tool_call.results"` | 工具执行回传 | 客户端工具结果上送 | +| `results` | array | 否 | 默认为空数组 | 多个工具结果 | 可批量回传 | +| `results[].tool_call_id` | string | 是 | 任意字符串 | 工具调用ID | 必须与 `assistant.tool_call.tool_call_id` 对应 | +| `results[].name` | string | 是 | 任意字符串 | 工具名 | 建议与请求一致 | +| `results[].output` | any | 否 | 任意 JSON | 工具输出 | 供模型后续组织回答 | +| `results[].status` | object | 是 | 包含 `code`、`message` | 执行状态 | 用于判定成功/失败 | +| `results[].status.code` | number | 是 | HTTP 风格状态码 | 状态码 | `200-299` 判定成功 | +| `results[].status.message` | string | 是 | 任意字符串 | 状态描述 | 例如 `"ok"` / `"timeout"` | + +处理规则: +- 未请求过的 `tool_call_id` 会被忽略(防止伪造/串话); +- 重复回传会被忽略; +- 超时未回传会由服务端合成超时结果(`504`)。 + +## 3.6 `session.stop` + +示例: + +```json +{ + "type": "session.stop", + "reason": "client_disconnect" +} +``` + +字段说明: + +| 字段 | 类型 | 必填 | 约束 | 含义 | 使用说明 | +|---|---|---|---|---|---| +| `type` | string | 是 | 固定 `"session.stop"` | 结束会话 | 正常结束推荐发送 | +| `reason` | string \| null | 否 | 任意字符串 | 结束原因 | 服务端会回传到 `session.stopped.reason` | + +--- + +## 4. 二进制音频输入(客户端 -> 服务端) + +在 `session.started` 之后可持续发送二进制音频。 + +固定格式(MVP): +- 编码:`pcm_s16le` +- 采样率:`16000` +- 声道:`1` +- 帧长:20ms = `640 bytes` + +分包规则: +- 单个 WebSocket 二进制消息可包含 1 帧或多帧; +- 长度必须是 `640` 的整数倍; +- 不是 `640` 倍数会触发 `audio.frame_size_mismatch`,该消息整包丢弃; +- 奇数字节长度会触发 `audio.invalid_pcm`。 + +--- + +## 5. 服务端 -> 客户端事件(输出) + +所有 JSON 事件都包含统一包络字段。 + +## 5.1 统一包络(Envelope) + +```json +{ + "type": "event.name", + "timestamp": 1730000000000, + "sessionId": "sess_xxx", + "seq": 42, + "source": "asr", + "trackId": "audio_in", + "data": {} +} +``` + +字段说明: + +| 字段 | 类型 | 含义 | 使用说明 | +|---|---|---|---| +| `type` | string | 事件类型 | 见下方事件清单 | +| `timestamp` | number | 事件时间戳(毫秒) | 由 `ev()` 生成 | +| `sessionId` | string | 会话ID | 同一连接固定 | +| `seq` | number | 单会话递增序号 | 可用于重放、去重、排序 | +| `source` | string | 事件来源 | 常见:`asr`/`llm`/`tts`/`tool`/`system`/`client`/`server` | +| `trackId` | string | 事件轨道 | 常用:`audio_in`/`audio_out`/`control` | +| `data` | object | 结构化数据 | 顶层业务字段会镜像进 `data` 以兼容旧客户端 | + +关联ID(在 `data` 内自动注入,存在时): +- `turn_id`:一次用户-助手对话轮次 +- `utterance_id`:一次用户语音话语 +- `response_id`:一次助手生成响应 +- `tool_call_id`:一次工具调用 +- `tts_id`:一次 TTS 播放段 + +## 5.2 事件类型与参数 + +### 5.2.1 会话与控制类 + +1. `hello.ack` +- 关键字段:`version` +- 含义:握手成功,应紧接着发送 `session.start` + +2. `session.started` +- 关键字段: + - `trackId` + - `tracks.audio_in` + - `tracks.audio_out` + - `tracks.control` + - `audio`(回显客户端声明的音频元信息) +- 含义:会话进入 ACTIVE,可发音频/文本 + +3. `config.resolved` +- 关键字段: + - `config.appId` + - `config.channel` + - `config.configVersionId` + - `config.prompt.sha256` + - `config.output` + - `config.services`(去密钥后的有效服务配置) + - `config.tools.allowlist` + - `config.tracks` +- 含义:服务端最终生效配置快照,便于前端展示与排错 + +4. `heartbeat` +- 关键字段:无业务字段(仅 envelope) +- 含义:保活心跳 +- 默认间隔:`heartbeat_interval_sec`(默认 50s) + +5. `session.stopped` +- 关键字段:`reason` +- 含义:会话结束确认 + +6. `error` +- 关键字段: + - `sender` + - `code` + - `message` + - `stage` + - `retryable` + - `trackId` + - `data.error`(结构化错误镜像) +- 含义:统一错误事件 + +### 5.2.2 识别与输入侧(ASR/VAD) + +1. `input.speech_started` +- 字段:`probability` +- 含义:检测到语音开始 + +2. `input.speech_stopped` +- 字段:`probability` +- 含义:检测到语音结束 + +3. `transcript.delta` +- 字段:`text` +- 含义:ASR 增量识别文本(节流发送) + +4. `transcript.final` +- 字段:`text` +- 含义:ASR 最终识别文本 + +### 5.2.3 输出侧(LLM/TTS/Tool) + +1. `assistant.response.delta` +- 字段:`text` +- 含义:助手增量文本输出(节流发送) + +2. `assistant.response.final` +- 字段:`text` +- 含义:助手完整文本输出 + +3. `assistant.tool_call` +- 字段: + - `tool_call_id` + - `tool_name` + - `arguments`(对象) + - `executor`(`client` 或 `server`) + - `timeout_ms` + - `tool_call`(完整工具调用对象) +- 含义:通知客户端发生工具调用(用于可视化或客户端执行) + +4. `assistant.tool_result` +- 字段: + - `source`(`client` 或 `server`) + - `tool_call_id` + - `tool_name` + - `ok`(boolean) + - `error`(失败时 `{code,message,retryable}`) + - `result`(原始结果对象) +- 含义:工具调用结果回执 + +5. `output.audio.start` +- 含义:TTS 音频输出开始边界 + +6. `output.audio.end` +- 含义:TTS 音频输出结束边界 + +7. `response.interrupted` +- 含义:当前回答被打断(barge-in 或 cancel) + +8. `metrics.ttfb` +- 字段:`latencyMs` +- 含义:首包音频时延(TTFB) + +### 5.2.4 工作流扩展事件(可选) + +若 `metadata.workflow` 生效,会额外出现: +- `workflow.started` +- `workflow.node.entered` +- `workflow.edge.taken` +- `workflow.tool.requested` +- `workflow.human_transfer` +- `workflow.ended` + +这些事件用于外部可视化工作流状态,不影响基础语音会话协议。 + +--- + +## 6. 服务端二进制音频输出(服务端 -> 客户端) + +- 音频为 PCM 二进制帧; +- 发送单位对齐到 `640 bytes`(不足会补零后发送); +- 前端通常结合 `output.audio.start/end` 做播放边界控制; +- 收到 `response.interrupted` 后应丢弃队列中未播放完的旧音频。 + +--- + +## 7. 错误模型与常见错误码 + +统一结构(`error` 事件): + +```json +{ + "type": "error", + "sender": "client", + "code": "protocol.invalid_message", + "message": "Invalid message: ...", + "stage": "protocol", + "retryable": false, + "trackId": "control", + "data": { + "error": { + "stage": "protocol", + "code": "protocol.invalid_message", + "message": "Invalid message: ...", + "retryable": false + } + } +} +``` + +字段语义: +- `sender`:错误来源角色(如 `client` / `server` / `auth`) +- `code`:机器可读错误码 +- `message`:人类可读描述 +- `stage`:阶段(`protocol|audio|asr|llm|tts|tool`) +- `retryable`:是否建议重试 +- `trackId`:错误归属轨道 + +常见错误码: +- `protocol.invalid_json` +- `protocol.invalid_message` +- `protocol.order` +- `protocol.version_unsupported` +- `protocol.unsupported` +- `auth.invalid_api_key` +- `auth.required` +- `audio.invalid_pcm` +- `audio.frame_size_mismatch` +- `audio.processing_failed` +- `server.internal` + +--- + +## 8. 心跳与超时 + +服务端后台任务逻辑: +- 每隔约 5 秒检查一次连接; +- 超过 `inactivity_timeout_sec`(默认 60 秒)未收到任何客户端消息则关闭会话; +- 每隔 `heartbeat_interval_sec`(默认 50 秒)发送一次 `heartbeat`。 + +客户端建议: +- 持续上行音频或定期发送轻量文本消息,避免被判定闲置; +- 用 `heartbeat` + `seq` 检测连接活性和事件乱序。 + +--- + +## 9. 实战接入建议 + +1. 建连后立即发送 `hello`,收到 `hello.ack` 后再发 `session.start`。 +2. 语音输入严格按 16k/16bit/mono,并保证每个 WS 二进制消息长度是 `640*n`。 +3. UI 层把 `assistant.response.delta` 当作流式显示,把 `assistant.response.final` 当作收敛结果。 +4. 播放器用 `output.audio.start/end` 管理一轮播报生命周期。 +5. 工具调用场景下,若 `executor=client`,务必按 `tool_call_id` 回传 `tool_call.results`。 +6. 出现 `error` 时优先按 `code` 分流处理,而不是仅看 `message`。 + +--- + +## 10. 最小完整时序示例 + +```text +Client -> hello +Server <- hello.ack +Client -> session.start +Server <- session.started +Server <- config.resolved +Client -> (binary pcm frames...) +Server <- input.speech_started / transcript.delta / transcript.final +Server <- assistant.response.delta / assistant.response.final +Server <- output.audio.start +Server <- (binary pcm frames...) +Server <- output.audio.end +Client -> session.stop +Server <- session.stopped +``` + diff --git a/requirements.txt b/requirements.txt index 3d38414..d117414 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ pydantic>=2.5.3 pydantic-settings>=2.1.0 python-dotenv>=1.0.0 toml>=0.10.2 +pyyaml>=6.0.1 # Logging loguru>=0.7.2 diff --git a/services/llm.py b/services/llm.py index a25ff26..51bfbe4 100644 --- a/services/llm.py +++ b/services/llm.py @@ -43,14 +43,14 @@ class OpenAILLMService(BaseLLMService): Args: model: Model name (e.g., "gpt-4o-mini", "gpt-4o") - api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars) base_url: Custom API base URL (for Azure or compatible APIs) system_prompt: Default system prompt for conversations """ super().__init__(model=model) - self.api_key = api_key or os.getenv("OPENAI_API_KEY") - self.base_url = base_url or os.getenv("OPENAI_API_URL") + self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL") self.system_prompt = system_prompt or ( "You are a helpful, friendly voice assistant. " "Keep your responses concise and conversational. " diff --git a/services/openai_compatible_asr.py b/services/openai_compatible_asr.py index daf7c04..bcf0fae 100644 --- a/services/openai_compatible_asr.py +++ b/services/openai_compatible_asr.py @@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti import asyncio import io +import os import wave from typing import AsyncIterator, Optional, Callable, Awaitable from loguru import logger @@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService): def __init__( self, - api_key: str, + api_key: Optional[str] = None, + api_url: Optional[str] = None, model: str = "FunAudioLLM/SenseVoiceSmall", sample_rate: int = 16000, language: str = "auto", @@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService): Args: api_key: Provider API key + api_url: Provider API URL (defaults to SiliconFlow endpoint) model: ASR model name or alias sample_rate: Audio sample rate (16000 recommended) language: Language code (auto for automatic detection) @@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService): if not AIOHTTP_AVAILABLE: raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") - self.api_key = api_key + self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY") + self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL self.model = self.MODELS.get(model.lower(), model) self.interim_interval_ms = interim_interval_ms self.min_audio_for_interim_ms = min_audio_for_interim_ms @@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService): async def connect(self) -> None: """Connect to the service.""" + if not self.api_key: + raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.") self._session = aiohttp.ClientSession( headers={ "Authorization": f"Bearer {self.api_key}" @@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService): ) form_data.add_field('model', self.model) - async with self._session.post(self.API_URL, data=form_data) as response: + async with self._session.post(self.api_url, data=form_data) as response: if response.status == 200: result = await response.json() text = result.get("text", "").strip() diff --git a/services/openai_compatible_tts.py b/services/openai_compatible_tts.py index 4967557..1abb1e5 100644 --- a/services/openai_compatible_tts.py +++ b/services/openai_compatible_tts.py @@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService): def __init__( self, api_key: Optional[str] = None, + api_url: Optional[str] = None, voice: str = "anna", model: str = "FunAudioLLM/CosyVoice2-0.5B", sample_rate: int = 16000, @@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService): Initialize OpenAI-compatible TTS service. Args: - api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var) + api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars) + api_url: Provider API URL (defaults to SiliconFlow endpoint) voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) model: Model name sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) @@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService): super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed) - self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY") + self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY") self.model = model - self.api_url = "https://api.siliconflow.cn/v1/audio/speech" + self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech" self._session: Optional[aiohttp.ClientSession] = None self._cancel_event = asyncio.Event() @@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService): async def connect(self) -> None: """Initialize HTTP session.""" if not self.api_key: - raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.") + raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.") self._session = aiohttp.ClientSession( headers={ diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py new file mode 100644 index 0000000..86fa0d4 --- /dev/null +++ b/tests/test_agent_config.py @@ -0,0 +1,204 @@ +import os +from pathlib import Path + +import pytest + +os.environ.setdefault("LLM_API_KEY", "test-openai-key") +os.environ.setdefault("TTS_API_KEY", "test-tts-key") +os.environ.setdefault("ASR_API_KEY", "test-asr-key") + +from app.config import load_settings + + +def _write_yaml(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + + +def _full_agent_yaml(llm_model: str = "gpt-4o-mini", llm_key: str = "test-openai-key") -> str: + return f""" +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.63 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + provider: openai_compatible + model: {llm_model} + temperature: 0.2 + api_key: {llm_key} + api_url: https://example-llm.invalid/v1 + + tts: + provider: openai_compatible + api_key: test-tts-key + api_url: https://example-tts.invalid/v1/audio/speech + model: FunAudioLLM/CosyVoice2-0.5B + voice: anna + speed: 1.0 + + asr: + provider: openai_compatible + api_key: test-asr-key + api_url: https://example-asr.invalid/v1/audio/transcriptions + model: FunAudioLLM/SenseVoiceSmall + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: You are a strict test assistant. + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 +""".strip() + + +def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + config_dir = tmp_path / "config" / "agents" + _write_yaml( + config_dir / "support.yaml", + _full_agent_yaml(llm_model="gpt-4.1-mini"), + ) + + settings = load_settings( + argv=["--agent-profile", "support"], + ) + + assert settings.llm_model == "gpt-4.1-mini" + assert settings.llm_temperature == 0.2 + assert settings.vad_threshold == 0.63 + assert settings.agent_config_source == "cli_profile" + assert settings.agent_config_path == str((config_dir / "support.yaml").resolve()) + + +def test_cli_path_has_higher_priority_than_env(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + env_file = tmp_path / "config" / "agents" / "env.yaml" + cli_file = tmp_path / "config" / "agents" / "cli.yaml" + + _write_yaml(env_file, _full_agent_yaml(llm_model="env-model")) + _write_yaml(cli_file, _full_agent_yaml(llm_model="cli-model")) + + monkeypatch.setenv("AGENT_CONFIG_PATH", str(env_file)) + + settings = load_settings(argv=["--agent-config", str(cli_file)]) + + assert settings.llm_model == "cli-model" + assert settings.agent_config_source == "cli_path" + assert settings.agent_config_path == str(cli_file.resolve()) + + +def test_default_yaml_is_loaded_without_args_or_env(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + default_file = tmp_path / "config" / "agents" / "default.yaml" + _write_yaml(default_file, _full_agent_yaml(llm_model="from-default")) + + monkeypatch.delenv("AGENT_CONFIG_PATH", raising=False) + monkeypatch.delenv("AGENT_PROFILE", raising=False) + + settings = load_settings(argv=[]) + + assert settings.llm_model == "from-default" + assert settings.agent_config_source == "default" + assert settings.agent_config_path == str(default_file.resolve()) + + +def test_missing_required_agent_settings_fail(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-required.yaml" + _write_yaml( + file_path, + """ +agent: + llm: + model: gpt-4o-mini +""".strip(), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_blank_required_provider_key_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "blank-key.yaml" + _write_yaml(file_path, _full_agent_yaml(llm_key="")) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_missing_tts_api_url_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-tts-url.yaml" + _write_yaml( + file_path, + _full_agent_yaml().replace( + " api_url: https://example-tts.invalid/v1/audio/speech\n", + "", + ), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_missing_asr_api_url_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "missing-asr-url.yaml" + _write_yaml( + file_path, + _full_agent_yaml().replace( + " api_url: https://example-asr.invalid/v1/audio/transcriptions\n", + "", + ), + ) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_agent_yaml_unknown_key_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "bad-agent.yaml" + _write_yaml(file_path, _full_agent_yaml() + "\n unknown_option: true") + + with pytest.raises(ValueError, match="Unknown agent config keys"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_legacy_siliconflow_section_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "legacy-siliconflow.yaml" + _write_yaml( + file_path, + """ +agent: + siliconflow: + api_key: x +""".strip(), + ) + + with pytest.raises(ValueError, match="Section 'siliconflow' is no longer supported"): + load_settings(argv=["--agent-config", str(file_path)]) + + +def test_agent_yaml_missing_env_reference_fails(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "bad-ref.yaml" + _write_yaml( + file_path, + _full_agent_yaml(llm_key="${UNSET_LLM_API_KEY}"), + ) + + with pytest.raises(ValueError, match="Missing environment variable"): + load_settings(argv=["--agent-config", str(file_path)])