119 lines
3.4 KiB
Python
119 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
|
|
@dataclass(frozen=True)
class ServerConfig:
    """Listening-server settings: bind host, port and allowed CORS origins."""

    # Address to bind (presumably); "0.0.0.0" conventionally means all IPv4 interfaces.
    host: str = field(default="0.0.0.0")
    # Port to listen on.
    port: int = field(default=8000)
    # Origins to allow for CORS; empty by default.
    # NOTE(review): the dataclass is frozen but this list is still mutable,
    # and a list field makes hash(ServerConfig(...)) raise TypeError.
    # Consider a tuple if immutability or hashability matters.
    cors_origins: list[str] = field(default_factory=lambda: [])
|
|
|
|
|
|
@dataclass(frozen=True)
class AudioConfig:
    """Audio stream format: sample rate, channel count and frame duration.

    Samples are treated as 2 bytes wide (16-bit) when sizing frames; see
    :attr:`frame_bytes`.
    """

    # Bytes per sample per channel. Unannotated, so dataclass does not turn
    # it into a field (it is not a constructor argument and not in repr/eq).
    _SAMPLE_WIDTH_BYTES = 2

    sample_rate_hz: int = 16000
    channels: int = 1
    frame_ms: int = 20

    @property
    def frame_bytes(self) -> int:
        """Return the size in bytes of one ``frame_ms``-long frame.

        Computed as ``(sample_rate_hz * frame_ms // 1000) * channels * 2``.
        Exact integer floor division is used instead of float division plus
        ``int()``, so the result cannot be perturbed by float rounding. A
        fractional trailing sample (e.g. 22050 Hz at 10 ms = 220.5 samples)
        is dropped.
        """
        samples_per_channel = self.sample_rate_hz * self.frame_ms // 1000
        return samples_per_channel * self.channels * self._SAMPLE_WIDTH_BYTES
|
|
|
|
|
|
@dataclass(frozen=True)
class SessionConfig:
    """Per-session lifecycle settings."""

    # Presumably: seconds without activity after which a session is ended.
    # Enforcement is not in this module — confirm at the consumer.
    inactivity_timeout_sec: int = field(default=60)
|
|
|
|
|
|
@dataclass(frozen=True)
class AgentConfig:
    """Assistant persona and greeting behaviour."""

    # System prompt text given to the assistant.
    system_prompt: str = field(default="You are a helpful, friendly voice assistant.")
    # Optional greeting text; config_from_dict turns "" into None.
    greeting: str | None = field(default=None)
    # Expected to be "generated", "fixed" or "off". Only config_from_dict
    # checks this; constructing AgentConfig directly performs no validation.
    greeting_mode: str = field(default="generated")
|
|
|
|
|
|
@dataclass(frozen=True)
class LLMConfig:
    """Connection and sampling settings for the language-model service.

    ``api_key`` is excluded from the generated ``repr()`` so that printing or
    logging a config (including one nested in a larger config object) does
    not leak the credential. It still takes part in equality comparisons.
    """

    provider: str = "openai"
    # Provider credential; defaults to "" (not configured). Hidden from repr.
    api_key: str = field(default="", repr=False)
    # Optional endpoint override; presumably None means the provider default.
    base_url: str | None = None
    model: str = "gpt-4o-mini"
    # Sampling temperature; presumably None defers to the provider default.
    temperature: float | None = 0.7
|
|
|
|
|
|
@dataclass(frozen=True)
class STTConfig:
    """Connection settings for the speech-to-text service.

    ``api_key`` is excluded from the generated ``repr()`` so that printing or
    logging a config (including one nested in a larger config object) does
    not leak the credential. It still takes part in equality comparisons.
    """

    provider: str = "openai"
    # Provider credential; defaults to "" (not configured). Hidden from repr.
    api_key: str = field(default="", repr=False)
    # Optional endpoint override; presumably None means the provider default.
    base_url: str | None = None
    model: str = "gpt-4o-mini-transcribe"
    # Language hint; config_from_dict turns "" into None. Presumably None
    # lets the provider detect the language — confirm at the call site.
    language: str | None = "en"
|
|
|
|
|
|
@dataclass(frozen=True)
class TTSConfig:
    """Connection and voice settings for the text-to-speech service.

    ``api_key`` is excluded from the generated ``repr()`` so that printing or
    logging a config (including one nested in a larger config object) does
    not leak the credential. It still takes part in equality comparisons.
    """

    provider: str = "openai"
    # Provider credential; defaults to "" (not configured). Hidden from repr.
    api_key: str = field(default="", repr=False)
    # Optional endpoint override; presumably None means the provider default.
    base_url: str | None = None
    model: str = "gpt-4o-mini-tts"
    voice: str = "alloy"
    # NOTE(review): presumably the sample rate of audio produced by the
    # provider, with None meaning "unknown / use a default" — confirm.
    source_sample_rate_hz: int | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
class ServicesConfig:
    """Bundle of the external service settings: LLM, STT and TTS."""

    # Each default is built fresh per instance from the class's own defaults.
    llm: LLMConfig = field(default_factory=lambda: LLMConfig())
    stt: STTConfig = field(default_factory=lambda: STTConfig())
    tts: TTSConfig = field(default_factory=lambda: TTSConfig())
|
|
|
|
|
|
@dataclass(frozen=True)
class EngineConfig:
    """Root configuration object, as produced by ``config_from_dict``.

    Every section defaults to its own dataclass defaults, so
    ``EngineConfig()`` is a complete, usable configuration.
    """

    server: ServerConfig = field(default_factory=lambda: ServerConfig())
    audio: AudioConfig = field(default_factory=lambda: AudioConfig())
    session: SessionConfig = field(default_factory=lambda: SessionConfig())
    agent: AgentConfig = field(default_factory=lambda: AgentConfig())
    services: ServicesConfig = field(default_factory=lambda: ServicesConfig())
|
|
|
|
|
|
def load_config(path: str | Path = "config.json") -> EngineConfig:
    """Read a JSON config file and convert it into an :class:`EngineConfig`.

    If ``path`` is exactly ``"config.json"`` and no such file exists relative
    to the current working directory, ``config.json`` in the parent of this
    module's directory is used instead.

    Args:
        path: Location of the JSON file.

    Returns:
        The configuration built by ``config_from_dict``.

    Raises:
        OSError: If the file cannot be read (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file is not valid JSON. The message is
            prefixed with the path that was actually read, since the fallback
            above may have silently swapped in a different file.
        ValueError: If the top-level JSON value is not an object.
    """
    config_path = Path(path)
    if not config_path.exists() and str(path) == "config.json":
        config_path = Path(__file__).resolve().parent.parent / "config.json"

    text = config_path.read_text(encoding="utf-8")
    try:
        data = json.loads(text)
    except json.JSONDecodeError as err:
        # Keep the same exception type (callers may catch it) but name the file.
        raise json.JSONDecodeError(f"{config_path}: {err.msg}", err.doc, err.pos) from err

    if not isinstance(data, dict):
        raise ValueError(f"Config file must contain a JSON object: {config_path}")

    return config_from_dict(data)
|
|
|
|
|
|
def config_from_dict(data: dict) -> EngineConfig:
    """Build an :class:`EngineConfig` from an already-decoded mapping.

    The top-level sections ``server``, ``audio``, ``session``, ``agent`` and
    ``services`` (with ``llm``, ``stt``/``asr`` and ``tts``) are all optional.
    A missing or non-dict section yields that dataclass's defaults. Keys
    inside a section are passed directly to the matching dataclass.

    Normalisations applied before construction:

    * ``agent.greeting == ""`` becomes ``None``.
    * ``agent.greeting_mode`` set to ``null`` is treated as absent, so
      AgentConfig's default applies instead of storing ``None``.
    * ``services.asr`` is used when ``services.stt`` is missing or empty.
    * ``services.stt.language == ""`` becomes ``None``.

    Args:
        data: The decoded configuration object.

    Returns:
        A populated, frozen :class:`EngineConfig`.

    Raises:
        ValueError: If ``agent.greeting_mode`` is set to something other than
            ``generated``, ``fixed`` or ``off``.
        TypeError: If a section contains a key its dataclass does not define
            (raised by the dataclass constructor).
    """
    services = _dict(data.get("services"))

    agent = _dict(data.get("agent"))
    if agent.get("greeting") == "":
        agent["greeting"] = None
    greeting_mode = agent.get("greeting_mode")
    if greeting_mode is None:
        # Absent or explicit JSON null: drop the key so the dataclass default
        # is used rather than forwarding greeting_mode=None.
        agent.pop("greeting_mode", None)
    elif greeting_mode not in ("generated", "fixed", "off"):
        raise ValueError("agent.greeting_mode must be one of: generated, fixed, off")

    # "asr" is accepted as an alternate name for the "stt" section.
    stt = _dict(services.get("stt") or services.get("asr"))
    if stt.get("language") == "":
        stt["language"] = None

    return EngineConfig(
        server=ServerConfig(**_dict(data.get("server"))),
        audio=AudioConfig(**_dict(data.get("audio"))),
        session=SessionConfig(**_dict(data.get("session"))),
        agent=AgentConfig(**agent),
        services=ServicesConfig(
            llm=LLMConfig(**_dict(services.get("llm"))),
            stt=STTConfig(**stt),
            tts=TTSConfig(**_dict(services.get("tts"))),
        ),
    )
|
|
|
|
|
|
def _dict(value: object) -> dict:
|
|
return dict(value) if isinstance(value, dict) else {}
|