Organize config

This commit is contained in:
Xin Wang
2026-02-25 15:52:55 +08:00
parent 2b2193557d
commit 8b9064f6e6
12 changed files with 1248 additions and 92 deletions

View File

@@ -1,9 +1,354 @@
"""Configuration management using Pydantic settings."""
"""Configuration management using Pydantic settings and agent YAML profiles."""
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
try:
import yaml
except ImportError: # pragma: no cover - validated when agent YAML is used
yaml = None
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
"vad": {
"type": "vad_type",
"model_path": "vad_model_path",
"threshold": "vad_threshold",
"min_speech_duration_ms": "vad_min_speech_duration_ms",
"eou_threshold_ms": "vad_eou_threshold_ms",
},
"llm": {
"provider": "llm_provider",
"model": "llm_model",
"temperature": "llm_temperature",
"api_key": "llm_api_key",
"api_url": "llm_api_url",
},
"tts": {
"provider": "tts_provider",
"api_key": "tts_api_key",
"api_url": "tts_api_url",
"model": "tts_model",
"voice": "tts_voice",
"speed": "tts_speed",
},
"asr": {
"provider": "asr_provider",
"api_key": "asr_api_key",
"api_url": "asr_api_url",
"model": "asr_model",
"interim_interval_ms": "asr_interim_interval_ms",
"min_audio_ms": "asr_min_audio_ms",
"start_min_speech_ms": "asr_start_min_speech_ms",
"pre_speech_ms": "asr_pre_speech_ms",
"final_tail_ms": "asr_final_tail_ms",
},
"duplex": {
"enabled": "duplex_enabled",
"greeting": "duplex_greeting",
"system_prompt": "duplex_system_prompt",
},
"barge_in": {
"min_duration_ms": "barge_in_min_duration_ms",
"silence_tolerance_ms": "barge_in_silence_tolerance_ms",
},
}
_AGENT_SETTING_KEYS = {
"vad_type",
"vad_model_path",
"vad_threshold",
"vad_min_speech_duration_ms",
"vad_eou_threshold_ms",
"llm_provider",
"llm_api_key",
"llm_api_url",
"llm_model",
"llm_temperature",
"tts_provider",
"tts_api_key",
"tts_api_url",
"tts_model",
"tts_voice",
"tts_speed",
"asr_provider",
"asr_api_key",
"asr_api_url",
"asr_model",
"asr_interim_interval_ms",
"asr_min_audio_ms",
"asr_start_min_speech_ms",
"asr_pre_speech_ms",
"asr_final_tail_ms",
"duplex_enabled",
"duplex_greeting",
"duplex_system_prompt",
"barge_in_min_duration_ms",
"barge_in_silence_tolerance_ms",
}
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
"vad_type",
"vad_model_path",
"vad_threshold",
"vad_min_speech_duration_ms",
"vad_eou_threshold_ms",
"llm_provider",
"llm_model",
"llm_temperature",
"tts_provider",
"tts_voice",
"tts_speed",
"asr_provider",
"asr_interim_interval_ms",
"asr_min_audio_ms",
"asr_start_min_speech_ms",
"asr_pre_speech_ms",
"asr_final_tail_ms",
"duplex_enabled",
"duplex_system_prompt",
"barge_in_min_duration_ms",
"barge_in_silence_tolerance_ms",
}
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
return str(overrides.get(key) or default).strip().lower()
def _is_blank(value: Any) -> bool:
return value is None or (isinstance(value, str) and not value.strip())
@dataclass(frozen=True)
class AgentConfigSelection:
"""Resolved agent config location and how it was selected."""
path: Optional[Path]
source: str
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
"""Parse only agent-related CLI flags from argv."""
config_path: Optional[str] = None
profile: Optional[str] = None
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith("--agent-config="):
config_path = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-config" and i + 1 < len(argv):
config_path = argv[i + 1].strip() or None
i += 1
elif arg.startswith("--agent-profile="):
profile = arg.split("=", 1)[1].strip() or None
elif arg == "--agent-profile" and i + 1 < len(argv):
profile = argv[i + 1].strip() or None
i += 1
i += 1
return config_path, profile
def _agent_config_dir() -> Path:
base_dir = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR))
if not base_dir.is_absolute():
base_dir = Path.cwd() / base_dir
return base_dir.resolve()
def _resolve_agent_selection(
agent_config_path: Optional[str] = None,
agent_profile: Optional[str] = None,
argv: Optional[List[str]] = None,
) -> AgentConfigSelection:
cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:]))
path_value = agent_config_path or cli_path or os.getenv("AGENT_CONFIG_PATH")
profile_value = agent_profile or cli_profile or os.getenv("AGENT_PROFILE")
source = "none"
candidate: Optional[Path] = None
if path_value:
source = "cli_path" if (agent_config_path or cli_path) else "env_path"
candidate = Path(path_value)
elif profile_value:
source = "cli_profile" if (agent_profile or cli_profile) else "env_profile"
candidate = _agent_config_dir() / f"{profile_value}.yaml"
else:
fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
if fallback.exists():
source = "default"
candidate = fallback
if candidate is None:
raise ValueError(
"Agent YAML config is required. Provide --agent-config/--agent-profile "
"or create config/agents/default.yaml."
)
if not candidate.is_absolute():
candidate = (Path.cwd() / candidate).resolve()
else:
candidate = candidate.resolve()
if not candidate.exists():
raise ValueError(f"Agent config file not found ({source}): {candidate}")
if not candidate.is_file():
raise ValueError(f"Agent config path is not a file: {candidate}")
return AgentConfigSelection(path=candidate, source=source)
def _resolve_env_refs(value: Any) -> Any:
"""Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively."""
if isinstance(value, dict):
return {k: _resolve_env_refs(v) for k, v in value.items()}
if isinstance(value, list):
return [_resolve_env_refs(item) for item in value]
if not isinstance(value, str) or "${" not in value:
return value
def _replace(match: re.Match[str]) -> str:
env_key = match.group(1)
default_value = match.group(2)
env_value = os.getenv(env_key)
if env_value is None:
if default_value is None:
raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}")
return default_value
return env_value
return _ENV_REF_PATTERN.sub(_replace, value)
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize YAML into flat Settings fields."""
normalized: Dict[str, Any] = {}
for key, value in raw.items():
if key == "siliconflow":
raise ValueError(
"Section 'siliconflow' is no longer supported. "
"Move provider-specific fields into agent.llm / agent.asr / agent.tts."
)
section_map = _AGENT_SECTION_KEY_MAP.get(key)
if section_map is None:
normalized[key] = value
continue
if not isinstance(value, dict):
raise ValueError(f"Agent config section '{key}' must be a mapping")
for nested_key, nested_value in value.items():
mapped_key = section_map.get(nested_key)
if mapped_key is None:
raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'")
normalized[mapped_key] = nested_value
unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS)
if unknown_keys:
raise ValueError(
"Unknown agent config keys in YAML: "
+ ", ".join(unknown_keys)
)
return normalized
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
missing = set(_BASE_REQUIRED_AGENT_SETTING_KEYS - set(overrides))
string_required = {
"vad_type",
"vad_model_path",
"llm_provider",
"llm_model",
"tts_provider",
"tts_voice",
"asr_provider",
"duplex_system_prompt",
}
for key in string_required:
if key in overrides and _is_blank(overrides.get(key)):
missing.add(key)
llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai":
if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")):
missing.add("llm_api_key")
tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")):
missing.add("tts_api_key")
if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")):
missing.add("tts_api_url")
if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")):
missing.add("tts_model")
asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")):
missing.add("asr_api_key")
if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")):
missing.add("asr_api_url")
if "asr_model" not in overrides or _is_blank(overrides.get("asr_model")):
missing.add("asr_model")
return sorted(missing)
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
if yaml is None:
raise RuntimeError(
"PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
)
with selection.path.open("r", encoding="utf-8") as file:
raw = yaml.safe_load(file) or {}
if not isinstance(raw, dict):
raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")
if "agent" in raw:
agent_value = raw["agent"]
if not isinstance(agent_value, dict):
raise ValueError("The 'agent' key in YAML must be a mapping")
raw = agent_value
resolved = _resolve_env_refs(raw)
overrides = _normalize_agent_overrides(resolved)
missing_required = _missing_required_keys(overrides)
if missing_required:
raise ValueError(
f"Missing required agent settings in YAML ({selection.path}): "
+ ", ".join(missing_required)
)
overrides["agent_config_path"] = str(selection.path)
overrides["agent_config_source"] = selection.source
return overrides
def load_settings(
agent_config_path: Optional[str] = None,
agent_profile: Optional[str] = None,
argv: Optional[List[str]] = None,
) -> "Settings":
"""Load settings from .env and optional agent YAML."""
selection = _resolve_agent_selection(
agent_config_path=agent_config_path,
agent_profile=agent_profile,
argv=argv,
)
agent_overrides = _load_agent_overrides(selection)
return Settings(**agent_overrides)
class Settings(BaseSettings):
@@ -37,30 +382,35 @@ class Settings(BaseSettings):
vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
# OpenAI / LLM Configuration
openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
# LLM Configuration
llm_provider: str = Field(
default="openai",
description="LLM provider (openai, openai_compatible, siliconflow)"
)
llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
# TTS Configuration
tts_provider: str = Field(
default="openai_compatible",
description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
description="TTS provider (edge, openai_compatible, siliconflow)"
)
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
tts_model: Optional[str] = Field(default=None, description="TTS model name")
tts_voice: str = Field(default="anna", description="TTS voice name")
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
# SiliconFlow Configuration
siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
# ASR Configuration
asr_provider: str = Field(
default="openai_compatible",
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
description="ASR provider (openai_compatible, buffered, siliconflow)"
)
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
asr_start_min_speech_ms: int = Field(
@@ -122,6 +472,10 @@ class Settings(BaseSettings):
backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
# Agent YAML metadata
agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
agent_config_source: str = Field(default="none", description="How the agent YAML was selected")
@property
def chunk_size_bytes(self) -> int:
"""Calculate chunk size in bytes based on sample rate and duration."""
@@ -146,7 +500,7 @@ class Settings(BaseSettings):
# Global settings instance
settings = Settings()
settings = load_settings()
def get_settings() -> Settings: