531 lines
20 KiB
Python
531 lines
20 KiB
Python
"""Configuration management using Pydantic settings and agent YAML profiles."""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from pydantic import Field
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
try:
|
|
import yaml
|
|
except ImportError: # pragma: no cover - validated when agent YAML is used
|
|
yaml = None
|
|
|
|
|
|
_ENV_REF_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")
|
|
_DEFAULT_AGENT_CONFIG_DIR = "config/agents"
|
|
_DEFAULT_AGENT_CONFIG_FILE = "default.yaml"
|
|
_AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
|
|
"vad": {
|
|
"type": "vad_type",
|
|
"model_path": "vad_model_path",
|
|
"threshold": "vad_threshold",
|
|
"min_speech_duration_ms": "vad_min_speech_duration_ms",
|
|
"eou_threshold_ms": "vad_eou_threshold_ms",
|
|
},
|
|
"llm": {
|
|
"provider": "llm_provider",
|
|
"model": "llm_model",
|
|
"temperature": "llm_temperature",
|
|
"api_key": "llm_api_key",
|
|
"api_url": "llm_api_url",
|
|
},
|
|
"tts": {
|
|
"provider": "tts_provider",
|
|
"api_key": "tts_api_key",
|
|
"api_url": "tts_api_url",
|
|
"model": "tts_model",
|
|
"voice": "tts_voice",
|
|
"speed": "tts_speed",
|
|
},
|
|
"asr": {
|
|
"provider": "asr_provider",
|
|
"api_key": "asr_api_key",
|
|
"api_url": "asr_api_url",
|
|
"model": "asr_model",
|
|
"interim_interval_ms": "asr_interim_interval_ms",
|
|
"min_audio_ms": "asr_min_audio_ms",
|
|
"start_min_speech_ms": "asr_start_min_speech_ms",
|
|
"pre_speech_ms": "asr_pre_speech_ms",
|
|
"final_tail_ms": "asr_final_tail_ms",
|
|
},
|
|
"duplex": {
|
|
"enabled": "duplex_enabled",
|
|
"greeting": "duplex_greeting",
|
|
"system_prompt": "duplex_system_prompt",
|
|
},
|
|
"barge_in": {
|
|
"min_duration_ms": "barge_in_min_duration_ms",
|
|
"silence_tolerance_ms": "barge_in_silence_tolerance_ms",
|
|
},
|
|
}
|
|
_AGENT_SETTING_KEYS = {
|
|
"vad_type",
|
|
"vad_model_path",
|
|
"vad_threshold",
|
|
"vad_min_speech_duration_ms",
|
|
"vad_eou_threshold_ms",
|
|
"llm_provider",
|
|
"llm_api_key",
|
|
"llm_api_url",
|
|
"llm_model",
|
|
"llm_temperature",
|
|
"tts_provider",
|
|
"tts_api_key",
|
|
"tts_api_url",
|
|
"tts_model",
|
|
"tts_voice",
|
|
"tts_speed",
|
|
"asr_provider",
|
|
"asr_api_key",
|
|
"asr_api_url",
|
|
"asr_model",
|
|
"asr_interim_interval_ms",
|
|
"asr_min_audio_ms",
|
|
"asr_start_min_speech_ms",
|
|
"asr_pre_speech_ms",
|
|
"asr_final_tail_ms",
|
|
"duplex_enabled",
|
|
"duplex_greeting",
|
|
"duplex_system_prompt",
|
|
"barge_in_min_duration_ms",
|
|
"barge_in_silence_tolerance_ms",
|
|
"tools",
|
|
}
|
|
_BASE_REQUIRED_AGENT_SETTING_KEYS = {
|
|
"vad_type",
|
|
"vad_model_path",
|
|
"vad_threshold",
|
|
"vad_min_speech_duration_ms",
|
|
"vad_eou_threshold_ms",
|
|
"llm_provider",
|
|
"llm_model",
|
|
"llm_temperature",
|
|
"tts_provider",
|
|
"tts_voice",
|
|
"tts_speed",
|
|
"asr_provider",
|
|
"asr_interim_interval_ms",
|
|
"asr_min_audio_ms",
|
|
"asr_start_min_speech_ms",
|
|
"asr_pre_speech_ms",
|
|
"asr_final_tail_ms",
|
|
"duplex_enabled",
|
|
"duplex_system_prompt",
|
|
"barge_in_min_duration_ms",
|
|
"barge_in_silence_tolerance_ms",
|
|
}
|
|
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
|
|
|
|
|
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
|
|
return str(overrides.get(key) or default).strip().lower()
|
|
|
|
|
|
def _is_blank(value: Any) -> bool:
|
|
return value is None or (isinstance(value, str) and not value.strip())
|
|
|
|
|
|
@dataclass(frozen=True)
class AgentConfigSelection:
    """Immutable record of which agent YAML file was chosen and why.

    ``path`` is the file to load (typed as optional, although the loader
    opens it directly). ``source`` is a short tag describing how the file
    was selected, for example ``"cli_path"``, ``"env_profile"`` or ``"default"``.
    """

    # Location of the agent YAML file.
    path: Optional[Path]
    # Selection origin tag; later stored as the ``agent_config_source`` override.
    source: str
|
|
|
|
|
|
def _parse_cli_agent_args(argv: List[str]) -> Tuple[Optional[str], Optional[str]]:
|
|
"""Parse only agent-related CLI flags from argv."""
|
|
config_path: Optional[str] = None
|
|
profile: Optional[str] = None
|
|
i = 0
|
|
while i < len(argv):
|
|
arg = argv[i]
|
|
if arg.startswith("--agent-config="):
|
|
config_path = arg.split("=", 1)[1].strip() or None
|
|
elif arg == "--agent-config" and i + 1 < len(argv):
|
|
config_path = argv[i + 1].strip() or None
|
|
i += 1
|
|
elif arg.startswith("--agent-profile="):
|
|
profile = arg.split("=", 1)[1].strip() or None
|
|
elif arg == "--agent-profile" and i + 1 < len(argv):
|
|
profile = argv[i + 1].strip() or None
|
|
i += 1
|
|
i += 1
|
|
return config_path, profile
|
|
|
|
|
|
def _agent_config_dir() -> Path:
    """Return the absolute, resolved directory that holds agent profile YAMLs.

    Uses ``AGENT_CONFIG_DIR`` when set, otherwise ``_DEFAULT_AGENT_CONFIG_DIR``.
    A relative location is anchored at the current working directory.
    """
    configured = Path(os.getenv("AGENT_CONFIG_DIR", _DEFAULT_AGENT_CONFIG_DIR))
    anchored = configured if configured.is_absolute() else Path.cwd() / configured
    return anchored.resolve()
|
|
|
|
|
|
def _resolve_agent_selection(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> AgentConfigSelection:
    """Decide which agent YAML file to load.

    Precedence (first match wins):

    1. explicit path: ``agent_config_path`` argument, else ``--agent-config``;
    2. explicit profile: ``agent_profile`` argument, else ``--agent-profile``,
       resolved to ``<agent config dir>/<profile>.yaml``;
    3. ``AGENT_CONFIG_PATH`` environment variable;
    4. ``AGENT_PROFILE`` environment variable (resolved like a profile);
    5. ``default.yaml`` in the agent config dir, if that file exists.

    Explicit choices (arguments / CLI flags) are checked before the
    environment, so an exported ``AGENT_CONFIG_PATH`` can no longer silently
    override a requested ``--agent-profile``.

    Args:
        agent_config_path: Path to a YAML file; relative paths use the CWD.
        agent_profile: Profile name looked up in the agent config dir.
        argv: Command-line tokens to scan; defaults to ``sys.argv[1:]``.

    Returns:
        An ``AgentConfigSelection`` with an absolute, resolved path. Its
        source is one of ``cli_path``, ``cli_profile``, ``env_path``,
        ``env_profile`` or ``default``; function arguments count as ``cli_*``.

    Raises:
        ValueError: nothing was selected and no default file exists, or the
            chosen path does not exist or is not a regular file.
    """
    cli_path, cli_profile = _parse_cli_agent_args(list(argv if argv is not None else sys.argv[1:]))
    explicit_path = agent_config_path or cli_path
    explicit_profile = agent_profile or cli_profile
    env_path = os.getenv("AGENT_CONFIG_PATH")
    env_profile = os.getenv("AGENT_PROFILE")

    if explicit_path:
        source, candidate = "cli_path", Path(explicit_path)
    elif explicit_profile:
        source, candidate = "cli_profile", _agent_config_dir() / f"{explicit_profile}.yaml"
    elif env_path:
        source, candidate = "env_path", Path(env_path)
    elif env_profile:
        source, candidate = "env_profile", _agent_config_dir() / f"{env_profile}.yaml"
    else:
        fallback = _agent_config_dir() / _DEFAULT_AGENT_CONFIG_FILE
        if not fallback.exists():
            # Report the real fallback location; AGENT_CONFIG_DIR may move it.
            raise ValueError(
                "Agent YAML config is required. Provide --agent-config/--agent-profile "
                f"(or AGENT_CONFIG_PATH/AGENT_PROFILE) or create {fallback}."
            )
        source, candidate = "default", fallback

    # Anchor relative paths at the current working directory, then normalise.
    if not candidate.is_absolute():
        candidate = Path.cwd() / candidate
    candidate = candidate.resolve()

    if not candidate.exists():
        raise ValueError(f"Agent config file not found ({source}): {candidate}")
    if not candidate.is_file():
        raise ValueError(f"Agent config path is not a file: {candidate}")
    return AgentConfigSelection(path=candidate, source=source)
|
|
|
|
|
|
def _resolve_env_refs(value: Any) -> Any:
|
|
"""Resolve ${ENV_VAR} / ${ENV_VAR:default} placeholders recursively."""
|
|
if isinstance(value, dict):
|
|
return {k: _resolve_env_refs(v) for k, v in value.items()}
|
|
if isinstance(value, list):
|
|
return [_resolve_env_refs(item) for item in value]
|
|
if not isinstance(value, str) or "${" not in value:
|
|
return value
|
|
|
|
def _replace(match: re.Match[str]) -> str:
|
|
env_key = match.group(1)
|
|
default_value = match.group(2)
|
|
env_value = os.getenv(env_key)
|
|
if env_value is None:
|
|
if default_value is None:
|
|
raise ValueError(f"Missing environment variable referenced in agent YAML: {env_key}")
|
|
return default_value
|
|
return env_value
|
|
|
|
return _ENV_REF_PATTERN.sub(_replace, value)
|
|
|
|
|
|
def _normalize_agent_overrides(raw: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Normalize YAML into flat Settings fields."""
|
|
normalized: Dict[str, Any] = {}
|
|
|
|
for key, value in raw.items():
|
|
if key == "siliconflow":
|
|
raise ValueError(
|
|
"Section 'siliconflow' is no longer supported. "
|
|
"Move provider-specific fields into agent.llm / agent.asr / agent.tts."
|
|
)
|
|
if key == "tools":
|
|
if not isinstance(value, list):
|
|
raise ValueError("Agent config key 'tools' must be a list")
|
|
normalized["tools"] = value
|
|
continue
|
|
section_map = _AGENT_SECTION_KEY_MAP.get(key)
|
|
if section_map is None:
|
|
normalized[key] = value
|
|
continue
|
|
|
|
if not isinstance(value, dict):
|
|
raise ValueError(f"Agent config section '{key}' must be a mapping")
|
|
|
|
for nested_key, nested_value in value.items():
|
|
mapped_key = section_map.get(nested_key)
|
|
if mapped_key is None:
|
|
raise ValueError(f"Unknown key in '{key}' section: '{nested_key}'")
|
|
normalized[mapped_key] = nested_value
|
|
|
|
unknown_keys = sorted(set(normalized) - _AGENT_SETTING_KEYS)
|
|
if unknown_keys:
|
|
raise ValueError(
|
|
"Unknown agent config keys in YAML: "
|
|
+ ", ".join(unknown_keys)
|
|
)
|
|
return normalized
|
|
|
|
|
|
def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
    """Return the sorted names of required agent settings that are absent or blank.

    Three layers of requirements are checked:

    * every key in ``_BASE_REQUIRED_AGENT_SETTING_KEYS`` must be present;
    * the string-valued keys listed below must not be ``None`` or
      whitespace-only when present;
    * provider credentials. An LLM provider of ``openai`` or any
      OpenAI-compatible name needs ``llm_api_key``. OpenAI-compatible TTS/ASR
      providers need their ``*_api_key``, ``*_api_url`` and ``*_model``.
      Provider names are trimmed and lower-cased, and a missing provider falls
      back to the same default the Settings field uses.
    """
    gaps = set(_BASE_REQUIRED_AGENT_SETTING_KEYS) - set(overrides)

    # For these keys, presence is not enough: None or whitespace counts as missing.
    non_blank_keys = (
        "vad_type",
        "vad_model_path",
        "llm_provider",
        "llm_model",
        "tts_provider",
        "tts_voice",
        "asr_provider",
        "duplex_system_prompt",
    )
    gaps.update(name for name in non_blank_keys if name in overrides and _is_blank(overrides[name]))

    def _require(*names: str) -> None:
        # An absent key yields None from .get(), which _is_blank treats as blank.
        gaps.update(name for name in names if _is_blank(overrides.get(name)))

    llm_kind = _normalized_provider(overrides, "llm_provider", "openai")
    if llm_kind == "openai" or llm_kind in _OPENAI_COMPATIBLE_PROVIDERS:
        _require("llm_api_key")

    credential_table = (
        ("tts_provider", ("tts_api_key", "tts_api_url", "tts_model")),
        ("asr_provider", ("asr_api_key", "asr_api_url", "asr_model")),
    )
    for provider_key, credential_keys in credential_table:
        if _normalized_provider(overrides, provider_key, "openai_compatible") in _OPENAI_COMPATIBLE_PROVIDERS:
            _require(*credential_keys)

    return sorted(gaps)
|
|
|
|
|
|
def _load_agent_overrides(selection: AgentConfigSelection) -> Dict[str, Any]:
    """Read, expand, flatten and validate the agent YAML named by ``selection``.

    The document may be the settings mapping itself or may wrap it under a
    top-level ``agent`` key. In the wrapped case, other top-level keys are
    ignored. ``${VAR}`` / ``${VAR:default}`` placeholders are expanded before
    the mapping is flattened.

    Returns:
        Flat Settings overrides, plus ``agent_config_path`` (the file path as
        a string) and ``agent_config_source`` (the selection tag).

    Raises:
        RuntimeError: PyYAML is not installed.
        ValueError: the document or its ``agent`` value is not a mapping, a
            placeholder or key is invalid, or required settings are missing.
    """
    if yaml is None:
        raise RuntimeError(
            "PyYAML is required for agent YAML configuration. Install with: pip install pyyaml"
        )

    with selection.path.open("r", encoding="utf-8") as handle:
        # safe_load only builds plain Python types; an empty file loads as None.
        document = yaml.safe_load(handle) or {}

    if not isinstance(document, dict):
        raise ValueError(f"Agent config must be a YAML mapping: {selection.path}")

    if "agent" in document:
        document = document["agent"]
        if not isinstance(document, dict):
            raise ValueError("The 'agent' key in YAML must be a mapping")

    overrides = _normalize_agent_overrides(_resolve_env_refs(document))

    missing = _missing_required_keys(overrides)
    if missing:
        raise ValueError(
            f"Missing required agent settings in YAML ({selection.path}): {', '.join(missing)}"
        )

    overrides.update(
        agent_config_path=str(selection.path),
        agent_config_source=selection.source,
    )
    return overrides
|
|
|
|
|
|
def load_settings(
    agent_config_path: Optional[str] = None,
    agent_profile: Optional[str] = None,
    argv: Optional[List[str]] = None,
) -> "Settings":
    """Build ``Settings`` from the selected agent YAML plus environment / ``.env``.

    An agent YAML file is required: selection raises ``ValueError`` when none
    can be found. The flattened YAML values are passed to ``Settings`` as
    constructor arguments. Under pydantic-settings' default source priority,
    constructor arguments outrank environment variables and ``.env`` for the
    same field.

    Args:
        agent_config_path: Explicit YAML path (same role as ``--agent-config``).
        agent_profile: Profile name (same role as ``--agent-profile``).
        argv: Command-line tokens to scan; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError / RuntimeError: propagated from YAML selection and loading.
    """
    chosen = _resolve_agent_selection(
        agent_config_path=agent_config_path,
        agent_profile=agent_profile,
        argv=argv,
    )
    return Settings(**_load_agent_overrides(chosen))
|
|
|
|
|
|
class Settings(BaseSettings):
    """Application settings.

    Values come from constructor arguments (``load_settings`` passes the
    flattened agent YAML this way), environment variables and the ``.env``
    file. The field defaults below are the last resort. Environment variable
    names are matched case-insensitively, and extra inputs are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")

    # Audio Configuration
    sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
    chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
    default_codec: str = Field(default="pcm", description="Default audio codec")
    max_audio_buffer_seconds: int = Field(
        default=30,
        description="Maximum buffered user audio duration kept in memory for current turn"
    )

    # VAD Configuration
    vad_type: str = Field(default="silero", description="VAD algorithm type")
    vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
    vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
    vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")

    # LLM Configuration
    llm_provider: str = Field(
        default="openai",
        description="LLM provider (openai, openai_compatible, siliconflow)"
    )
    llm_api_key: Optional[str] = Field(default=None, description="LLM provider API key")
    llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL")
    llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
    llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")

    # TTS Configuration
    tts_provider: str = Field(
        default="openai_compatible",
        description="TTS provider (edge, openai_compatible, siliconflow)"
    )
    tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
    tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
    tts_model: Optional[str] = Field(default=None, description="TTS model name")
    tts_voice: str = Field(default="anna", description="TTS voice name")
    tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")

    # ASR Configuration
    asr_provider: str = Field(
        default="openai_compatible",
        description="ASR provider (openai_compatible, buffered, siliconflow)"
    )
    asr_api_key: Optional[str] = Field(default=None, description="ASR provider API key")
    asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
    asr_model: Optional[str] = Field(default=None, description="ASR model name")
    asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
    asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
    asr_start_min_speech_ms: int = Field(
        default=160,
        description="Minimum continuous speech duration before ASR capture starts"
    )
    asr_pre_speech_ms: int = Field(
        default=240,
        description="Audio context (ms) prepended before detected speech to avoid clipping first phoneme"
    )
    asr_final_tail_ms: int = Field(
        default=120,
        description="Silence tail (ms) appended before final ASR decode to protect utterance ending"
    )

    # Duplex Pipeline Configuration
    duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline")
    duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message")
    duplex_system_prompt: Optional[str] = Field(
        default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.",
        description="System prompt for LLM"
    )

    # Barge-in (interruption) Configuration
    barge_in_min_duration_ms: int = Field(
        default=200,
        description="Minimum speech duration (ms) required to trigger barge-in. Lower=more sensitive."
    )
    barge_in_silence_tolerance_ms: int = Field(
        default=60,
        description="How much silence (ms) is tolerated during potential barge-in before reset"
    )

    # Optional tool declarations from agent YAML.
    # Supports OpenAI function schema style entries and/or shorthand string names.
    tools: List[Any] = Field(default_factory=list, description="Default tool definitions for runtime")

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")

    # CORS
    cors_origins: str = Field(
        default='["http://localhost:3000", "http://localhost:8080"]',
        description="CORS allowed origins"
    )

    # ICE Servers (WebRTC)
    ice_servers: str = Field(
        default='[{"urls": "stun:stun.l.google.com:19302"}]',
        description="ICE servers configuration"
    )

    # WebSocket heartbeat and inactivity
    inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
    heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
    ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
    ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
    ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")

    # Backend bridge configuration (for call/transcript persistence)
    backend_mode: str = Field(
        default="auto",
        description="Backend integration mode: auto | http | disabled"
    )
    backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)")
    backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
    history_enabled: bool = Field(default=True, description="Enable history write bridge")
    history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")
    history_queue_max_size: int = Field(default=256, description="Max buffered transcript writes per session")
    history_retry_max_attempts: int = Field(default=2, description="Retry attempts for each transcript write")
    history_retry_backoff_sec: float = Field(default=0.2, description="Base retry backoff for transcript writes")
    history_finalize_drain_timeout_sec: float = Field(
        default=1.5,
        description="Max wait before finalizing history when queue is still draining"
    )

    # Agent YAML metadata
    agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path")
    agent_config_source: str = Field(default="none", description="How the agent YAML was selected")

    @property
    def chunk_size_bytes(self) -> int:
        """Calculate chunk size in bytes based on sample rate and duration."""
        # 16-bit (2 bytes) per sample, mono channel
        return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))

    @property
    def cors_origins_list(self) -> List[str]:
        """Parse ``cors_origins`` into a list of origin strings.

        Accepted forms:

        * a JSON array, e.g. ``["https://a.example", "https://b.example"]``
          (returned as parsed);
        * a JSON string, e.g. ``"*"`` (wrapped in a one-item list);
        * plain comma-separated text, e.g. ``https://a.example, https://b.example``
          or ``*``.

        The localhost defaults are used when the text is empty or the JSON has
        some other type.
        """
        default = ["http://localhost:3000", "http://localhost:8080"]
        try:
            parsed = json.loads(self.cors_origins)
        except json.JSONDecodeError:
            # Not JSON: accept the comma-separated form common in .env files
            # instead of silently discarding the configured origins.
            origins = [item.strip() for item in self.cors_origins.split(",") if item.strip()]
            return origins or default
        if isinstance(parsed, list):
            return parsed
        if isinstance(parsed, str):
            return [parsed]
        return default

    @property
    def ice_servers_list(self) -> List[dict]:
        """Parse ``ice_servers`` JSON into a list of ICE server entries.

        A JSON array is returned as parsed, and a single JSON object is wrapped
        in a one-item list. Invalid JSON or any other JSON type falls back to
        the default public STUN server.
        """
        default = [{"urls": "stun:stun.l.google.com:19302"}]
        try:
            parsed = json.loads(self.ice_servers)
        except json.JSONDecodeError:
            return default
        if isinstance(parsed, list):
            return parsed
        if isinstance(parsed, dict):
            # A lone server object must still be returned as a list.
            return [parsed]
        return default
|
|
|
|
|
|
# Module-wide Settings instance, built once when this module is first imported.
# Importing the module therefore selects and validates the agent YAML, and it
# raises if no usable config can be found.
settings = load_settings()


def get_settings() -> Settings:
    """Return the shared Settings instance created at import time."""
    return settings
|