345 lines
11 KiB
Python
345 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
SUPPORTED_LLM_PROVIDERS = frozenset({"openai", "fastgpt"})
|
|
_LLM_PROVIDER_ALIASES = {"llm": "openai", "openai": "openai", "fastgpt": "fastgpt"}
|
|
|
|
|
|
@dataclass(frozen=True)
class ServerConfig:
    """Settings for the network server and the optional bundled web page."""

    # Bind address and port (consumed outside this module).
    host: str = "0.0.0.0"
    port: int = 8000
    # Allowed cross-origin origins; a fresh empty list per instance by default.
    cors_origins: list[str] = field(default_factory=list)
    # Whether the demo web page is served, and the URL path it is mounted at.
    serve_webpage: bool = True
    webpage_mount: str = "/voice-demo"
|
|
|
|
|
|
@dataclass(frozen=True)
class AudioConfig:
    """Audio stream format used to size fixed-duration frames."""

    sample_rate_hz: int = 16000
    channels: int = 1
    frame_ms: int = 20

    @property
    def frame_bytes(self) -> int:
        """Return the byte length of one ``frame_ms`` frame.

        The per-channel sample count is truncated to an integer first, then
        scaled by the channel count and a fixed sample width of 2 bytes.
        """
        samples_per_channel = int(self.sample_rate_hz * self.frame_ms / 1000)
        bytes_per_sample = 2
        return samples_per_channel * self.channels * bytes_per_sample
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class AudioFilterConfig:
|
|
"""Optional input audio filter applied by the Pipecat transport before VAD/STT."""
|
|
|
|
enabled: bool = False
|
|
provider: str = "none"
|
|
lib_path: str | None = None
|
|
model_path: str | None = None
|
|
model_sample_rate_hz: int = 48000
|
|
atten_lim_db: float = 100.0
|
|
post_filter_beta: float = 0.0
|
|
log_level: str | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
class SessionConfig:
    """Per-session lifecycle settings."""

    # Seconds without activity before the session is treated as inactive
    # (enforcement lives outside this module — confirm semantics there).
    inactivity_timeout_sec: int = 60
|
|
|
|
|
|
@dataclass(frozen=True)
class VADConfig:
    """Silero Voice Activity Detection thresholds.

    Each field corresponds one-to-one with
    ``pipecat.audio.vad.vad_analyzer.VADParams``. The defaults lean slightly
    more conservative than upstream pipecat, so brief pauses inside
    continuous speech do not prematurely end the user's turn.
    """

    confidence: float = 0.7
    start_secs: float = 0.2
    stop_secs: float = 0.6
    min_volume: float = 0.6
|
|
|
|
|
|
@dataclass(frozen=True)
class TurnConfig:
    """Policy for segmenting user turns and handling interruptions.

    ``user_speech_timeout_sec`` is the grace period (seconds) that follows
    VAD-confirmed silence: if the user resumes speaking within it, the turn
    continues; otherwise the aggregator finalizes the turn. It is consumed by
    ``SpeechTimeoutUserTurnStopStrategy``. Larger values tolerate natural
    mid-sentence pauses; smaller values make turn-taking snappier.

    In total, a user must pause for roughly
    ``vad.stop_secs + user_speech_timeout_sec`` before the turn ends.
    """

    vad: VADConfig = field(default_factory=VADConfig)
    user_speech_timeout_sec: float = 1.0

    # Idle prompting; presumably a timeout of 0 disables it (TODO confirm in consumer).
    idle_prompt_timeout_sec: float = 0.0
    idle_prompt_max_count: int = 1
    idle_prompt_text: str = (
        "我先停在这里。你可以继续说你的想法,"
        "或者让我根据刚才的内容帮你整理下一步。"
    )

    # Interruption heuristics; exact semantics live in the consumer (TODO confirm).
    interruption_min_chars: int = 3
    interruption_use_interim: bool = True
    interruption_short_replies: list[str] = field(
        default_factory=lambda: [
            # Chinese affirmative replies
            "是", "是的", "对", "对的", "嗯",
            "好", "好的", "行", "可以", "没问题",
            # Chinese negative replies
            "不是", "不", "不行", "不用", "不要", "没有", "否",
            # English replies
            "no", "yes", "ok", "okay",
        ]
    )
|
|
|
|
|
|
@dataclass(frozen=True)
class ResponseStateConfig:
    """Settings for an optional tagged "state" prefix in LLM responses.

    ``config_from_dict`` rejects an empty ``tag`` or ``event_type`` and a
    ``max_prefix_chars`` below 1. How the tag is parsed is defined by the
    consumer (TODO confirm).
    """

    enabled: bool = False
    tag: str = "state"
    event_type: str = "response.state"
    max_prefix_chars: int = 256
|
|
|
|
|
|
@dataclass(frozen=True)
class AgentConfig:
    """Persona and greeting behaviour of the voice agent."""

    system_prompt: str = "You are a helpful, friendly voice assistant."
    # config_from_dict maps an empty-string greeting to None.
    greeting: str | None = None
    # config_from_dict accepts: generated, fixed, off, fastgpt_opener
    # (the last one only with the FastGPT LLM provider).
    greeting_mode: str = "generated"
    fastgpt_reconnect_greeting: str = "欢迎回来继续对话"
    response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class LLMConfig:
|
|
"""LLM backend selection via ``provider``.
|
|
|
|
Set ``provider`` to ``"openai"`` (alias ``"llm"``) for OpenAI-compatible chat
|
|
completions, or ``"fastgpt"`` for FastGPT server-side memory via runtime ``chatId``.
|
|
"""
|
|
|
|
provider: str = "openai"
|
|
api_key: str = ""
|
|
base_url: str | None = None
|
|
model: str = "gpt-4o-mini"
|
|
app_id: str | None = None
|
|
temperature: float | None = 0.7
|
|
chat_id: str | None = None
|
|
variables: dict[str, str] = field(default_factory=dict)
|
|
detail: bool = False
|
|
timeout_sec: float = 60.0
|
|
# FastGPT image input mode: "base64" (inline data URL) or "upload" (presigned upload).
|
|
image_input_mode: str = "base64"
|
|
|
|
@property
|
|
def is_fastgpt(self) -> bool:
|
|
return self.provider == "fastgpt"
|
|
|
|
@property
|
|
def is_openai(self) -> bool:
|
|
return self.provider == "openai"
|
|
|
|
@property
|
|
def uses_local_context_history(self) -> bool:
|
|
"""Whether the pipeline should seed and maintain local LLM context history."""
|
|
return not self.is_fastgpt
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class STTConfig:
|
|
provider: str = "openai"
|
|
app_id: str = ""
|
|
api_key: str = ""
|
|
api_secret: str = ""
|
|
base_url: str | None = None
|
|
model: str = "gpt-4o-mini-transcribe"
|
|
language: str | None = "en"
|
|
domain: str = "iat"
|
|
accent: str = "mandarin"
|
|
encoding: str = "raw"
|
|
frame_size: int = 1280
|
|
timeout_sec: float = 10.0
|
|
dynamic_correction: bool = False
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class TTSConfig:
|
|
provider: str = "openai"
|
|
app_id: str = ""
|
|
api_key: str = ""
|
|
api_secret: str = ""
|
|
base_url: str | None = None
|
|
model: str = "gpt-4o-mini-tts"
|
|
voice: str = "alloy"
|
|
aue: str = "raw"
|
|
tte: str = "UTF8"
|
|
speed: int = 50
|
|
volume: int = 50
|
|
pitch: int = 50
|
|
timeout_sec: float = 30.0
|
|
source_sample_rate_hz: int | None = None
|
|
oral_level: str = "mid"
|
|
text_aggregation_mode: str | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
class ServicesConfig:
    """External AI services used by the pipeline: LLM, speech-to-text, text-to-speech."""

    llm: LLMConfig = field(default_factory=LLMConfig)
    # config_from_dict also reads this from a "services.asr" section when "stt" is absent.
    stt: STTConfig = field(default_factory=STTConfig)
    tts: TTSConfig = field(default_factory=TTSConfig)
|
|
|
|
|
|
@dataclass(frozen=True)
class EngineConfig:
    """Root configuration object; each attribute mirrors one top-level JSON section."""

    server: ServerConfig = field(default_factory=ServerConfig)
    audio: AudioConfig = field(default_factory=AudioConfig)
    audio_filter: AudioFilterConfig = field(default_factory=AudioFilterConfig)
    session: SessionConfig = field(default_factory=SessionConfig)
    turn: TurnConfig = field(default_factory=TurnConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)
    services: ServicesConfig = field(default_factory=ServicesConfig)
|
|
|
|
|
|
def load_config(path: str | Path = "config.json") -> EngineConfig:
    """Read a JSON config file and turn it into an :class:`EngineConfig`.

    If *path* is exactly the default ``"config.json"`` and that file does not
    exist (relative to the current working directory), ``config.json`` in the
    parent of this module's directory is used instead.

    Raises:
        OSError: If the selected file cannot be read (e.g. FileNotFoundError).
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is not an object.
    """
    target = Path(path)
    if not target.exists() and str(path) == "config.json":
        module_root = Path(__file__).resolve().parent.parent
        target = module_root / "config.json"

    payload = json.loads(target.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        return config_from_dict(payload)
    raise ValueError(f"Config file must contain a JSON object: {target}")
|
|
|
|
|
|
def config_from_dict(data: dict) -> EngineConfig:
    """Build an :class:`EngineConfig` from an already-decoded config mapping.

    Every top-level section (``server``, ``audio``, ``audio_filter``,
    ``session``, ``turn``, ``agent``, ``services``) is optional; a missing or
    non-dict section falls back to the dataclass defaults. ``services.asr`` is
    used for speech-to-text when ``services.stt`` is missing or empty.

    Args:
        data: The decoded top-level JSON object.

    Returns:
        The fully constructed engine configuration.

    Raises:
        ValueError: For invalid values — unknown ``agent.greeting_mode`` or
            ``services.llm.provider``, bad ``services.llm.image_input_mode``,
            invalid ``agent.response_state`` settings, ``fastgpt_opener``
            without the FastGPT provider, a non-list
            ``turn.interruption_short_replies``, or numeric ``turn`` values
            that cannot be converted.
        TypeError: If a section passed straight to a dataclass contains a
            key that dataclass does not define.
    """
    services = _dict(data.get("services"))
    agent, response_state = _agent_section(data)

    stt = _dict(services.get("stt") or services.get("asr"))
    if stt.get("language") == "":
        stt["language"] = None

    llm = _llm_section(services)

    # Cross-section check: the FastGPT opener greeting needs the FastGPT backend.
    if agent.get("greeting_mode") == "fastgpt_opener" and llm["provider"] != "fastgpt":
        raise ValueError(
            "agent.greeting_mode='fastgpt_opener' requires services.llm.provider='fastgpt'"
        )

    return EngineConfig(
        server=ServerConfig(**_dict(data.get("server"))),
        audio=AudioConfig(**_dict(data.get("audio"))),
        audio_filter=AudioFilterConfig(**_normalize_audio_filter(_dict(data.get("audio_filter")))),
        session=SessionConfig(**_dict(data.get("session"))),
        turn=_turn_config(_dict(data.get("turn"))),
        agent=AgentConfig(**agent, response_state=response_state),
        services=ServicesConfig(
            llm=LLMConfig(**llm),
            stt=STTConfig(**stt),
            tts=TTSConfig(**_dict(services.get("tts"))),
        ),
    )


# Accepted values for ``agent.greeting_mode`` (omitting the key selects the default).
_GREETING_MODES = ("generated", "fixed", "off", "fastgpt_opener")


def _agent_section(data: dict) -> tuple[dict, ResponseStateConfig]:
    """Normalize and validate the ``agent`` section, splitting off ``response_state``."""
    agent = _dict(data.get("agent"))
    if agent.get("greeting") == "":
        agent["greeting"] = None
    # Treat an explicit JSON null like an omitted key; otherwise AgentConfig
    # would store greeting_mode=None instead of its default.
    if "greeting_mode" in agent and agent["greeting_mode"] is None:
        del agent["greeting_mode"]
    greeting_mode = agent.get("greeting_mode")
    if greeting_mode is not None and greeting_mode not in _GREETING_MODES:
        raise ValueError(
            "agent.greeting_mode must be one of: generated, fixed, off, fastgpt_opener"
        )

    # Popped so it is not passed twice when AgentConfig(**agent, ...) is built.
    response_state = ResponseStateConfig(**_dict(agent.pop("response_state", None)))
    if response_state.max_prefix_chars < 1:
        raise ValueError("agent.response_state.max_prefix_chars must be greater than 0")
    if not response_state.tag:
        raise ValueError("agent.response_state.tag must not be empty")
    if not response_state.event_type:
        raise ValueError("agent.response_state.event_type must not be empty")
    return agent, response_state


def _llm_section(services: dict) -> dict:
    """Normalize and validate ``services.llm`` into LLMConfig keyword arguments."""
    llm = _dict(services.get("llm"))
    llm["provider"] = _normalize_llm_provider(llm.get("provider", LLMConfig().provider))
    # Empty strings for optional identifiers mean "unset".
    for key in ("chat_id", "app_id"):
        if llm.get(key) == "":
            llm[key] = None
    if not isinstance(llm.get("variables"), dict):
        llm["variables"] = {}

    raw_mode = llm.get("image_input_mode", LLMConfig().image_input_mode)
    image_input_mode = str(raw_mode).strip().lower()
    if image_input_mode not in {"base64", "upload"}:
        raise ValueError(
            "services.llm.image_input_mode must be 'base64' or 'upload', "
            f"got {raw_mode!r}"
        )
    llm["image_input_mode"] = image_input_mode
    return llm


def _turn_config(turn: dict) -> TurnConfig:
    """Build a TurnConfig from the ``turn`` section, coercing scalar value types."""
    defaults = TurnConfig()
    vad = VADConfig(**_dict(turn.get("vad")))

    short_replies = turn.get("interruption_short_replies")
    if short_replies is None:
        short_replies = defaults.interruption_short_replies
    elif not isinstance(short_replies, (list, tuple)):
        # list() on a string would silently split it into single characters.
        raise ValueError(
            "turn.interruption_short_replies must be a list of strings, "
            f"got {short_replies!r}"
        )

    return TurnConfig(
        vad=vad,
        user_speech_timeout_sec=float(
            turn.get("user_speech_timeout_sec", defaults.user_speech_timeout_sec)
        ),
        idle_prompt_timeout_sec=float(
            turn.get("idle_prompt_timeout_sec", defaults.idle_prompt_timeout_sec)
        ),
        idle_prompt_max_count=int(
            turn.get("idle_prompt_max_count", defaults.idle_prompt_max_count)
        ),
        idle_prompt_text=str(turn.get("idle_prompt_text", defaults.idle_prompt_text)),
        interruption_min_chars=int(
            turn.get("interruption_min_chars", defaults.interruption_min_chars)
        ),
        interruption_use_interim=bool(
            turn.get("interruption_use_interim", defaults.interruption_use_interim)
        ),
        # Always a fresh list, never the caller's object.
        interruption_short_replies=list(short_replies),
    )
|
|
|
|
|
|
def _dict(value: object) -> dict:
|
|
return dict(value) if isinstance(value, dict) else {}
|
|
|
|
|
|
def _normalize_audio_filter(value: dict) -> dict:
|
|
if value.get("lib_path") == "":
|
|
value["lib_path"] = None
|
|
if value.get("model_path") == "":
|
|
value["model_path"] = None
|
|
if value.get("log_level") == "":
|
|
value["log_level"] = None
|
|
if "provider" in value:
|
|
value["provider"] = str(value["provider"]).strip().lower()
|
|
return value
|
|
|
|
|
|
def _normalize_llm_provider(value: object) -> str:
    """Map a user-supplied LLM provider name to its canonical form.

    Falsy values fall back to the LLMConfig default. The name is stripped and
    lower-cased, then resolved through the alias table.

    Raises:
        ValueError: If the name is not a known provider or alias.
    """
    key = str(value or LLMConfig().provider).strip().lower()
    canonical = _LLM_PROVIDER_ALIASES.get(key)
    if canonical is not None:
        return canonical
    accepted = ", ".join(sorted(SUPPORTED_LLM_PROVIDERS | {"llm"}))
    raise ValueError(
        f"services.llm.provider must be one of: {accepted}; got {value!r}"
    )
|