Files
engine-v5-pipecat-core/engine/services.py
2026-06-02 08:24:53 +08:00

221 lines
8.2 KiB
Python

from __future__ import annotations
from collections.abc import AsyncGenerator
from openai import BadRequestError
from openai import NOT_GIVEN
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
from pipecat.services.tts_service import TextAggregationMode
from pipecat.transcriptions.language import Language
from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig
from .fastgpt_llm import FastGPTLLMService, FastGPTLLMSettings
from .xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService
from .xfyun_super_tts import DEFAULT_XFYUN_SUPER_TTS_URL, XfyunSuperTTSService
from .xfyun_tts import DEFAULT_XFYUN_TTS_URL, XfyunTTSService
def create_stt_service(config: STTConfig, audio: AudioConfig | None = None):
    """Build the speech-to-text service selected by ``config.provider``.

    * ``"xfyun"`` -> :class:`XfyunASRService`. The input sample rate comes
      from ``audio.sample_rate_hz``, or 16000 Hz when no audio config is
      given. The language defaults to ``"zh_cn"`` when unset or empty.
    * ``"openai"`` -> :class:`OpenAISTTService`. The language is converted
      with ``_language``.

    Args:
        config: STT provider settings (credentials, endpoint, model, tuning).
        audio: Optional pipeline audio settings. Only the xfyun branch reads
            them.

    Returns:
        A configured pipecat STT service instance.

    Raises:
        ValueError: If the provider is neither ``"xfyun"`` nor ``"openai"``.
    """
    if config.provider != "xfyun":
        # Anything that is not xfyun must be the OpenAI provider.
        _require_provider(config.provider, "openai", "stt")
        openai_settings = OpenAISTTService.Settings(
            model=config.model,
            language=_language(config.language),
        )
        return OpenAISTTService(
            api_key=config.api_key or None,
            base_url=config.base_url,
            settings=openai_settings,
        )

    input_rate = 16000 if not audio else audio.sample_rate_hz
    return XfyunASRService(
        app_id=config.app_id,
        api_key=config.api_key or "",
        api_secret=config.api_secret,
        url=config.base_url or DEFAULT_XFYUN_ASR_URL,
        language=config.language or "zh_cn",
        domain=config.domain,
        accent=config.accent,
        sample_rate=input_rate,
        encoding=config.encoding,
        frame_size=config.frame_size,
        open_timeout=config.timeout_sec,
        dynamic_correction=config.dynamic_correction,
    )
def create_llm_service(
    config: LLMConfig,
    *,
    chat_id: str | None = None,
    session_variables: dict | None = None,
    greeting_prompt: str | None = None,
):
    """Build the LLM service described by ``config``.

    * FastGPT configs (``config.is_fastgpt``) -> :class:`FastGPTLLMService`.
      ``chat_id``, ``greeting_prompt`` and ``session_variables`` are only used
      here. ``session_variables`` overrides same-named keys from
      ``config.variables``. The model defaults to ``"fastgpt"`` and the base
      URL defaults to ``"http://localhost:3000"``.
    * OpenAI configs (``config.is_openai``) -> :class:`OpenAILLMService`. The
      temperature is omitted (``NOT_GIVEN``) when ``config.temperature`` is
      None.

    Raises:
        ValueError: If ``config`` is neither a FastGPT nor an OpenAI config.
    """
    if config.is_fastgpt:
        overrides = session_variables or {}
        fastgpt_settings = FastGPTLLMSettings(
            model=config.model or "fastgpt",
            variables={**config.variables, **overrides},
            detail=config.detail,
        )
        return FastGPTLLMService(
            api_key=config.api_key,
            base_url=config.base_url or "http://localhost:3000",
            chat_id=chat_id,
            app_id=config.app_id,
            greeting_prompt=greeting_prompt,
            timeout=config.timeout_sec,
            image_input_mode=config.image_input_mode,
            settings=fastgpt_settings,
        )

    if config.is_openai:
        # NOT_GIVEN leaves the temperature out of the request entirely.
        temperature = NOT_GIVEN if config.temperature is None else config.temperature
        return OpenAILLMService(
            api_key=config.api_key or None,
            base_url=config.base_url,
            settings=OpenAILLMService.Settings(
                model=config.model,
                temperature=temperature,
            ),
        )

    known_providers = ("openai", "fastgpt", "llm")
    supported = ", ".join(sorted(known_providers))
    raise ValueError(
        f"Unsupported llm provider {config.provider!r}; expected one of: {supported}"
    )
def create_tts_service(config: TTSConfig, audio: AudioConfig):
    """Build the text-to-speech service selected by ``config.provider``.

    * ``"xfyun"`` -> :class:`XfyunTTSService`. The source sample rate
      defaults to ``audio.sample_rate_hz`` and must be 8000 or 16000 Hz.
    * ``"xfyun_super"`` / ``"xfyun_super_tts"`` -> :class:`XfyunSuperTTSService`.
      The source sample rate defaults to 24000 Hz and must be 8000, 16000 or
      24000 Hz. Text aggregation defaults to ``TextAggregationMode.TOKEN``.
    * ``"openai"`` -> the stock :class:`OpenAITTSService` when the voice is one
      of ``VALID_VOICES`` and no source sample rate is configured; otherwise
      :class:`OpenAICompatibleTTSService`, which accepts arbitrary voice ids
      and honours ``config.source_sample_rate_hz``.

    Args:
        config: TTS provider settings (credentials, voice, rates, tuning).
        audio: Pipeline audio settings. ``audio.sample_rate_hz`` is passed as
            ``sample_rate`` to every service built here.

    Returns:
        A configured pipecat TTS service instance.

    Raises:
        ValueError: If the provider is unsupported, or if the configured (or
            defaulted) source sample rate is not accepted by the xfyun
            provider in use.
    """
    if config.provider == "xfyun":
        source_sample_rate = config.source_sample_rate_hz or audio.sample_rate_hz
        if source_sample_rate not in (8000, 16000):
            raise ValueError("Xfyun TTS source_sample_rate_hz must be 8000 or 16000")
        return XfyunTTSService(
            app_id=config.app_id,
            api_key=config.api_key or "",
            api_secret=config.api_secret,
            voice=config.voice,
            url=config.base_url or DEFAULT_XFYUN_TTS_URL,
            sample_rate=audio.sample_rate_hz,
            source_sample_rate=source_sample_rate,
            encoding=config.aue,
            text_encoding=config.tte,
            speed=config.speed,
            volume=config.volume,
            pitch=config.pitch,
            timeout=config.timeout_sec,
            push_stop_frames=True,
        )

    if config.provider in ("xfyun_super", "xfyun_super_tts"):
        source_sample_rate = config.source_sample_rate_hz or 24000
        if source_sample_rate not in (8000, 16000, 24000):
            raise ValueError(
                "Xfyun Super TTS source_sample_rate_hz must be 8000, 16000, or 24000"
            )
        text_aggregation_mode = config.text_aggregation_mode or TextAggregationMode.TOKEN
        return XfyunSuperTTSService(
            app_id=config.app_id,
            api_key=config.api_key or "",
            api_secret=config.api_secret,
            voice=config.voice,
            url=config.base_url or DEFAULT_XFYUN_SUPER_TTS_URL,
            sample_rate=audio.sample_rate_hz,
            source_sample_rate=source_sample_rate,
            encoding=config.aue,
            speed=config.speed,
            volume=config.volume,
            pitch=config.pitch,
            oral_level=config.oral_level,
            text_aggregation_mode=text_aggregation_mode,
            open_timeout=config.timeout_sec,
        )

    _require_provider(config.provider, "openai", "tts")
    settings = OpenAITTSService.Settings(
        model=config.model,
        voice=config.voice,
    )

    # The stock OpenAITTSService has no ``source_sample_rate`` parameter; that is
    # why OpenAICompatibleTTSService intercepts the keyword before calling the
    # base constructor. Handing the keyword to the stock class would either be
    # rejected or silently dropped. So the stock class is only used when there
    # is no custom source rate to honour.
    if config.voice in VALID_VOICES and not config.source_sample_rate_hz:
        return OpenAITTSService(
            api_key=config.api_key or None,
            base_url=config.base_url,
            sample_rate=audio.sample_rate_hz,
            settings=settings,
        )

    return OpenAICompatibleTTSService(
        api_key=config.api_key or None,
        base_url=config.base_url,
        sample_rate=audio.sample_rate_hz,
        source_sample_rate=config.source_sample_rate_hz,
        settings=settings,
    )
class OpenAICompatibleTTSService(OpenAITTSService):
    """OpenAI-compatible TTS service that permits provider-specific voice ids.

    The configured voice id is sent to the speech endpoint exactly as given,
    with no check against ``VALID_VOICES``. The sample rate of the returned
    PCM stream is configurable, for servers whose output rate differs from
    ``OPENAI_SAMPLE_RATE``.
    """

    def __init__(self, *, source_sample_rate: int | None = None, **kwargs):
        """Create the service.

        Args:
            source_sample_rate: Sample rate (Hz) of the PCM returned by the
                server. Falsy values fall back to ``OPENAI_SAMPLE_RATE``.
            **kwargs: Forwarded unchanged to :class:`OpenAITTSService`.
        """
        resolved_rate = source_sample_rate or OPENAI_SAMPLE_RATE
        super().__init__(**kwargs)
        self._source_sample_rate = resolved_rate

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` and yield the resulting frames.

        Requests raw ``"pcm"`` audio from the speech endpoint. The byte stream
        is fed to ``_stream_audio_frames_from_iterator`` with
        ``self._source_sample_rate`` as its input rate. TTFB metrics are stopped
        on the first produced frame. Any failure is reported as an
        :class:`ErrorFrame` rather than raised.
        """
        voice = self._settings.voice
        if not voice:
            yield ErrorFrame(error="TTS voice must be specified")
            return

        async def non_empty(chunks):
            # Skip zero-length chunks from the response byte stream.
            async for chunk in chunks:
                if chunk:
                    yield chunk

        try:
            request = {
                "input": text,
                "model": self._settings.model,
                "voice": voice,
                "response_format": "pcm",
            }
            # Optional fields are only sent when set to a truthy value.
            optional = {
                "instructions": self._settings.instructions,
                "speed": self._settings.speed,
            }
            request.update({key: val for key, val in optional.items() if val})

            async with self._client.audio.speech.with_streaming_response.create(
                **request
            ) as resp:
                if resp.status_code != 200:
                    error = await resp.text()
                    yield ErrorFrame(
                        error=f"TTS request failed (status: {resp.status_code}, error: {error})"
                    )
                    return

                await self.start_tts_usage_metrics(text)

                awaiting_first_frame = True
                async for frame in self._stream_audio_frames_from_iterator(
                    non_empty(resp.iter_bytes(self.chunk_size)),
                    in_sample_rate=self._source_sample_rate,
                    context_id=context_id,
                ):
                    if awaiting_first_frame:
                        awaiting_first_frame = False
                        await self.stop_ttfb_metrics()
                    yield frame
        except Exception as exc:
            # Bad requests and every other failure are reported the same way.
            yield ErrorFrame(error=f"TTS request failed: {exc}")
def _require_provider(actual: str, expected: str, service_name: str) -> None:
    """Ensure a service's configured provider is the one this factory supports.

    Args:
        actual: Provider name taken from the configuration.
        expected: The only provider name accepted at this call site.
        service_name: Short service label (e.g. ``"stt"``) used in the message.

    Raises:
        ValueError: If ``actual`` differs from ``expected``.
    """
    if actual == expected:
        return
    message = f"Unsupported {service_name} provider {actual!r}; expected {expected!r}"
    raise ValueError(message)
def _language(value: str | None) -> Language | str | None:
    """Map a configured language code to a pipecat ``Language`` member.

    The code is normalised: surrounding whitespace is stripped, ``-`` becomes
    ``_``, and the result is upper-cased. It is then looked up by member
    name, so ``"zh-CN"`` resolves to ``Language.ZH_CN`` if that member exists.
    Codes with no matching member are returned as the stripped plain string,
    leaving the downstream service to interpret them.

    Returns:
        The matching ``Language`` member, the stripped input string when no
        member matches, or None when ``value`` is None, empty or blank. An empty
        setting then means "unset", the same as ``config.language or ...``
        elsewhere in this module, instead of forwarding ``""`` as a language.
    """
    if value is None:
        return None
    code = value.strip()
    if not code:
        return None
    normalized = code.replace("-", "_").upper()
    # Consult enum members only; getattr() could also match other class attributes.
    return Language.__members__.get(normalized, code)