Files
engine-v5-pipecat-core/engine/config.py
2026-05-29 11:01:24 +08:00

328 lines
10 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
SUPPORTED_LLM_PROVIDERS = frozenset({"openai", "fastgpt"})
_LLM_PROVIDER_ALIASES = {"llm": "openai", "openai": "openai", "fastgpt": "fastgpt"}
@dataclass(frozen=True)
class ServerConfig:
    """HTTP server settings: bind address, CORS origins and the demo page."""

    # Address and port the server listens on.
    host: str = "0.0.0.0"
    port: int = 8000
    # Allowed CORS origins; a fresh empty list per instance by default.
    cors_origins: list[str] = field(default_factory=list)
    # Whether the bundled demo web page is served, and under which path.
    serve_webpage: bool = True
    webpage_mount: str = "/voice-demo"
@dataclass(frozen=True)
class AudioConfig:
    """PCM framing parameters: sample rate, channel count and frame duration."""

    sample_rate_hz: int = 16000
    channels: int = 1
    frame_ms: int = 20

    @property
    def frame_bytes(self) -> int:
        """Return the byte size of one audio frame.

        The per-channel sample count is truncated to an integer, then scaled
        by the channel count and 2 bytes per sample (16-bit samples).
        """
        samples_per_channel = int(self.sample_rate_hz * self.frame_ms / 1000)
        bytes_per_sample = 2
        return samples_per_channel * self.channels * bytes_per_sample
@dataclass(frozen=True)
class AudioFilterConfig:
    """Optional input audio filter applied by the Pipecat transport before VAD/STT.

    Disabled by default, with provider ``"none"``. The provider-specific
    fields below are plain pass-through values; their exact meaning depends on
    the filter implementation (NOTE(review): confirm in the transport setup).
    """

    enabled: bool = False
    # Normalized to lower case / stripped by ``_normalize_audio_filter``.
    provider: str = "none"
    # Optional filesystem locations; empty strings in JSON become None.
    lib_path: str | None = None
    model_path: str | None = None
    # Sample rate the filter model operates at.
    model_sample_rate_hz: int = 48000
    # Presumably attenuation limit (dB) and post-filter strength — confirm.
    atten_lim_db: float = 100.0
    post_filter_beta: float = 0.0
    log_level: str | None = None
@dataclass(frozen=True)
class SessionConfig:
    """Per-session settings."""

    # Seconds of inactivity tolerated; presumably the session is closed
    # afterwards — confirm in the session manager.
    inactivity_timeout_sec: int = 60
@dataclass(frozen=True)
class VADConfig:
    """Silero voice-activity-detection thresholds.

    Each field maps one-to-one onto
    ``pipecat.audio.vad.vad_analyzer.VADParams``. The defaults are slightly
    more conservative than upstream pipecat so that short pauses inside
    continuous speech do not end the user's turn prematurely.
    """

    confidence: float = 0.7
    start_secs: float = 0.2
    stop_secs: float = 0.6
    min_volume: float = 0.6
@dataclass(frozen=True)
class TurnConfig:
    """User-turn segmentation, idle-prompt and interruption policy.

    ``user_speech_timeout_sec`` is the grace window (in seconds) that starts
    once VAD has confirmed silence. During it the user may resume speaking
    before the aggregator finalizes the turn. It is consumed by
    ``SpeechTimeoutUserTurnStopStrategy``. Larger values tolerate natural
    mid-sentence pauses; smaller values make turn-taking snappier.

    The total "user pause before the turn ends" budget is roughly
    ``vad.stop_secs + user_speech_timeout_sec``.
    """

    # Silero VAD thresholds; a fresh VADConfig per instance.
    vad: VADConfig = field(default_factory=VADConfig)
    user_speech_timeout_sec: float = 1.0

    # Idle prompt settings. The timeout defaults to 0.0 (presumably
    # "disabled" — confirm in the idle handler). The default text is Chinese
    # for roughly: "I'll pause here. You can keep sharing your thoughts, or
    # let me help organize the next step based on what we just discussed."
    idle_prompt_timeout_sec: float = 0.0
    idle_prompt_max_count: int = 1
    idle_prompt_text: str = (
        "我先停在这里。你可以继续说你的想法,"
        "或者让我根据刚才的内容帮你整理下一步。"
    )

    # Interruption policy: minimum transcript length, whether interim
    # transcripts are considered, and short replies (Chinese / English
    # acknowledgements such as "yes", "ok", "no problem") treated specially.
    interruption_min_chars: int = 3
    interruption_use_interim: bool = True
    # NOTE(review): several entries below are empty strings (some repeated);
    # they may be single-character replies lost to an encoding/copy problem —
    # confirm against the upstream source before relying on them.
    interruption_short_replies: list[str] = field(
        default_factory=lambda: [
            "",
            "是的",
            "",
            "对的",
            "",
            "",
            "好的",
            "",
            "可以",
            "没问题",
            "不是",
            "",
            "不行",
            "不用",
            "不要",
            "没有",
            "",
            "no",
            "yes",
            "ok",
            "okay",
        ]
    )
@dataclass(frozen=True)
class ResponseStateConfig:
    """Settings for the optional response-state tag feature (off by default).

    ``config_from_dict`` requires a positive ``max_prefix_chars`` and
    non-empty ``tag`` / ``event_type`` when loading from JSON.
    """

    enabled: bool = False
    tag: str = "state"
    event_type: str = "response.state"
    max_prefix_chars: int = 256
@dataclass(frozen=True)
class AgentConfig:
    """Assistant persona and greeting behaviour.

    When loading from JSON, ``config_from_dict`` restricts ``greeting_mode``
    to ``"generated"``, ``"fixed"`` or ``"off"``. It also turns an empty
    ``greeting`` string into ``None``.
    """

    system_prompt: str = "You are a helpful, friendly voice assistant."
    greeting: str | None = None
    greeting_mode: str = "generated"
    # Nested response-state tag settings; a fresh default per instance.
    response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig)
@dataclass(frozen=True)
class LLMConfig:
    """LLM backend settings; ``provider`` selects the backend.

    Use ``"openai"`` (alias ``"llm"``) for OpenAI-compatible chat completions,
    or ``"fastgpt"`` for FastGPT with server-side memory keyed by ``chat_id``.
    Aliases are resolved by ``_normalize_llm_provider`` when loading JSON.
    """

    provider: str = "openai"
    api_key: str = ""
    base_url: str | None = None
    model: str = "gpt-4o-mini"
    app_id: str | None = None
    temperature: float | None = 0.7
    chat_id: str | None = None
    # Fresh empty dict per instance.
    variables: dict[str, str] = field(default_factory=dict)
    detail: bool = False
    timeout_sec: float = 60.0
    send_system_prompt: bool = False

    @property
    def is_fastgpt(self) -> bool:
        """True when the FastGPT backend is selected."""
        return self.provider == "fastgpt"

    @property
    def is_openai(self) -> bool:
        """True when the OpenAI-compatible backend is selected."""
        return self.provider == "openai"

    @property
    def uses_local_context_history(self) -> bool:
        """Whether the pipeline should seed and maintain local LLM context history.

        Always true for non-FastGPT providers. For FastGPT it follows
        ``send_system_prompt``.
        """
        if self.is_fastgpt:
            return self.send_system_prompt
        return True
@dataclass(frozen=True)
class STTConfig:
    """Speech-to-text service settings.

    When loading JSON, an empty ``language`` string is turned into ``None``.
    """

    provider: str = "openai"
    # Credentials; which ones are needed depends on the provider.
    app_id: str = ""
    api_key: str = ""
    api_secret: str = ""
    base_url: str | None = None
    model: str = "gpt-4o-mini-transcribe"
    language: str | None = "en"
    # NOTE(review): the fields below look provider-specific (presumably used
    # only by non-OpenAI backends) — confirm in the STT service factory.
    domain: str = "iat"
    accent: str = "mandarin"
    encoding: str = "raw"
    frame_size: int = 1280
    timeout_sec: float = 10.0
    dynamic_correction: bool = False
@dataclass(frozen=True)
class TTSConfig:
    """Text-to-speech service settings."""

    provider: str = "openai"
    # Credentials; which ones are needed depends on the provider.
    app_id: str = ""
    api_key: str = ""
    api_secret: str = ""
    base_url: str | None = None
    model: str = "gpt-4o-mini-tts"
    voice: str = "alloy"
    # NOTE(review): aue/tte/speed/volume/pitch/oral_level look
    # provider-specific (speed/volume/pitch presumably on a 0-100 scale) —
    # confirm in the TTS service factory.
    aue: str = "raw"
    tte: str = "UTF8"
    speed: int = 50
    volume: int = 50
    pitch: int = 50
    timeout_sec: float = 30.0
    source_sample_rate_hz: int | None = None
    oral_level: str = "mid"
    text_aggregation_mode: str | None = None
@dataclass(frozen=True)
class ServicesConfig:
    """External AI services: language model, speech-to-text and text-to-speech.

    ``config_from_dict`` also accepts ``services.asr`` in place of
    ``services.stt``.
    """

    llm: LLMConfig = field(default_factory=LLMConfig)
    stt: STTConfig = field(default_factory=STTConfig)
    tts: TTSConfig = field(default_factory=TTSConfig)
@dataclass(frozen=True)
class EngineConfig:
    """Root engine configuration; each section gets a fresh default instance."""

    server: ServerConfig = field(default_factory=ServerConfig)
    audio: AudioConfig = field(default_factory=AudioConfig)
    audio_filter: AudioFilterConfig = field(default_factory=AudioFilterConfig)
    session: SessionConfig = field(default_factory=SessionConfig)
    turn: TurnConfig = field(default_factory=TurnConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)
    services: ServicesConfig = field(default_factory=ServicesConfig)
def load_config(path: str | Path = "config.json") -> EngineConfig:
    """Read a JSON config file and build an :class:`EngineConfig` from it.

    If ``path`` is exactly ``"config.json"`` and no such file exists relative
    to the working directory, ``config.json`` in the parent of this module's
    directory is used instead.

    Raises:
        OSError: If the file cannot be read (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is not an object, or if
            ``config_from_dict`` rejects a value.
    """
    candidate = Path(path)
    fall_back_to_bundled = not candidate.exists() and str(path) == "config.json"
    if fall_back_to_bundled:
        candidate = Path(__file__).resolve().parent.parent / "config.json"
    parsed = json.loads(candidate.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return config_from_dict(parsed)
    raise ValueError(f"Config file must contain a JSON object: {candidate}")
def config_from_dict(data: dict) -> EngineConfig:
    """Build an :class:`EngineConfig` from a parsed JSON object.

    Missing or non-object sections fall back to the dataclass defaults. Keys
    in ``server``, ``audio``, ``audio_filter``, ``session``, ``turn.vad``,
    ``agent``, ``agent.response_state`` and ``services.{llm,stt,tts}`` are
    forwarded to the matching dataclass constructor, so unknown keys there
    raise ``TypeError``. Under ``turn``, only the known keys are read and any
    others are ignored. ``services.asr`` is accepted as a fallback for
    ``services.stt``.

    Args:
        data: Top-level configuration mapping, typically ``json.loads`` output.
            Sections are shallow-copied before normalization, so ``data``
            itself is not modified.

    Returns:
        The fully populated engine configuration.

    Raises:
        ValueError: If ``agent.greeting_mode`` is not one of ``generated``,
            ``fixed`` or ``off``. Also if ``agent.response_state`` has a
            non-positive ``max_prefix_chars`` or an empty ``tag`` /
            ``event_type``, or if ``services.llm.provider`` is unsupported.
        TypeError: If a forwarded section contains a key the dataclass does
            not declare.
    """
    services = _dict(data.get("services"))

    agent = _dict(data.get("agent"))
    if agent.get("greeting") == "":
        agent["greeting"] = None
    if agent.get("greeting_mode") is None:
        # Absent key and explicit JSON null both mean "use the AgentConfig
        # default" instead of storing None in a ``str`` field.
        agent.pop("greeting_mode", None)
    elif agent["greeting_mode"] not in ("generated", "fixed", "off"):
        raise ValueError("agent.greeting_mode must be one of: generated, fixed, off")

    # ``pop`` needs a default: configs that omit ``agent.response_state`` must
    # get ResponseStateConfig defaults instead of a KeyError.
    response_state = ResponseStateConfig(**_dict(agent.pop("response_state", None)))
    if response_state.max_prefix_chars < 1:
        raise ValueError("agent.response_state.max_prefix_chars must be greater than 0")
    if not response_state.tag:
        raise ValueError("agent.response_state.tag must not be empty")
    if not response_state.event_type:
        raise ValueError("agent.response_state.event_type must not be empty")

    stt = _dict(services.get("stt") or services.get("asr"))
    if stt.get("language") == "":
        stt["language"] = None

    llm = _dict(services.get("llm"))
    llm["provider"] = _normalize_llm_provider(llm.get("provider", LLMConfig().provider))
    # Empty strings for optional identifiers are normalized to None.
    for optional_key in ("chat_id", "app_id"):
        if llm.get(optional_key) == "":
            llm[optional_key] = None
    if not isinstance(llm.get("variables"), dict):
        llm["variables"] = {}

    turn = _dict(data.get("turn"))
    # Build the defaults once instead of re-instantiating TurnConfig per field.
    turn_defaults = TurnConfig()

    return EngineConfig(
        server=ServerConfig(**_dict(data.get("server"))),
        audio=AudioConfig(**_dict(data.get("audio"))),
        audio_filter=AudioFilterConfig(
            **_normalize_audio_filter(_dict(data.get("audio_filter")))
        ),
        session=SessionConfig(**_dict(data.get("session"))),
        turn=TurnConfig(
            vad=VADConfig(**_dict(turn.get("vad"))),
            user_speech_timeout_sec=float(
                turn.get("user_speech_timeout_sec", turn_defaults.user_speech_timeout_sec)
            ),
            idle_prompt_timeout_sec=float(
                turn.get("idle_prompt_timeout_sec", turn_defaults.idle_prompt_timeout_sec)
            ),
            idle_prompt_max_count=int(
                turn.get("idle_prompt_max_count", turn_defaults.idle_prompt_max_count)
            ),
            idle_prompt_text=str(
                turn.get("idle_prompt_text", turn_defaults.idle_prompt_text)
            ),
            interruption_min_chars=int(
                turn.get("interruption_min_chars", turn_defaults.interruption_min_chars)
            ),
            interruption_use_interim=bool(
                turn.get(
                    "interruption_use_interim", turn_defaults.interruption_use_interim
                )
            ),
            # ``list`` copies, so the config never aliases the caller's list.
            interruption_short_replies=list(
                turn.get(
                    "interruption_short_replies",
                    turn_defaults.interruption_short_replies,
                )
            ),
        ),
        agent=AgentConfig(**agent, response_state=response_state),
        services=ServicesConfig(
            llm=LLMConfig(**llm),
            stt=STTConfig(**stt),
            tts=TTSConfig(**_dict(services.get("tts"))),
        ),
    )
def _dict(value: object) -> dict:
return dict(value) if isinstance(value, dict) else {}
def _normalize_audio_filter(value: dict) -> dict:
if value.get("lib_path") == "":
value["lib_path"] = None
if value.get("model_path") == "":
value["model_path"] = None
if value.get("log_level") == "":
value["log_level"] = None
if "provider" in value:
value["provider"] = str(value["provider"]).strip().lower()
return value
def _normalize_llm_provider(value: object) -> str:
    """Map a user-supplied LLM provider name onto its canonical spelling.

    Falsy input (``None``, ``""``) selects the ``LLMConfig`` default provider.
    Matching ignores surrounding whitespace and case. ``"llm"`` is accepted as
    an alias of ``"openai"``.

    Raises:
        ValueError: If the name is neither a supported provider nor an alias.
    """
    chosen = value if value else LLMConfig().provider
    key = str(chosen).strip().lower()
    canonical = _LLM_PROVIDER_ALIASES.get(key)
    if canonical is not None:
        return canonical
    choices = ", ".join(sorted(SUPPORTED_LLM_PROVIDERS | {"llm"}))
    raise ValueError(
        f"services.llm.provider must be one of: {choices}; got {value!r}"
    )