Files
engine-v5-pipecat-core/engine/services.py
2026-05-21 13:08:40 +08:00

127 lines
4.5 KiB
Python

from __future__ import annotations
from collections.abc import AsyncGenerator
from openai import BadRequestError
from openai import NOT_GIVEN
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
from pipecat.transcriptions.language import Language
from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig
def create_stt_service(config: STTConfig):
    """Build the OpenAI speech-to-text service described by ``config``.

    Args:
        config: STT settings; ``provider``, ``api_key``, ``base_url``,
            ``model`` and ``language`` are read.

    Returns:
        A configured ``OpenAISTTService`` instance.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "stt")

    # Any falsy key (e.g. an empty string) is passed on as None.
    api_key = config.api_key or None
    base_url = config.base_url
    stt_settings = OpenAISTTService.Settings(
        model=config.model,
        language=_language(config.language),
    )
    return OpenAISTTService(
        api_key=api_key,
        base_url=base_url,
        settings=stt_settings,
    )
def create_llm_service(config: LLMConfig):
    """Build the OpenAI chat/LLM service described by ``config``.

    Args:
        config: LLM settings; ``provider``, ``api_key``, ``base_url``,
            ``model`` and ``temperature`` are read.

    Returns:
        A configured ``OpenAILLMService`` instance.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "llm")

    # Any falsy key (e.g. an empty string) is passed on as None.
    api_key = config.api_key or None
    base_url = config.base_url
    model = config.model

    # An unset temperature is replaced by the SDK's NOT_GIVEN sentinel
    # rather than being passed through as None.
    temperature = config.temperature
    if temperature is None:
        temperature = NOT_GIVEN

    llm_settings = OpenAILLMService.Settings(model=model, temperature=temperature)
    return OpenAILLMService(
        api_key=api_key,
        base_url=base_url,
        settings=llm_settings,
    )
def create_tts_service(config: TTSConfig, audio: AudioConfig):
    """Build the text-to-speech service described by ``config`` and ``audio``.

    The stock ``OpenAITTSService`` is used when ``config.voice`` is one of
    ``VALID_VOICES``; any other voice id selects
    ``OpenAICompatibleTTSService``.

    Args:
        config: TTS settings; ``provider``, ``api_key``, ``base_url``,
            ``model``, ``voice`` and ``source_sample_rate_hz`` are read.
        audio: Audio settings; ``sample_rate_hz`` is passed as ``sample_rate``.

    Returns:
        An instance of the selected TTS service class.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "tts")

    if config.voice in VALID_VOICES:
        service_class = OpenAITTSService
    else:
        service_class = OpenAICompatibleTTSService

    # Any falsy key (e.g. an empty string) is passed on as None.
    api_key = config.api_key or None
    base_url = config.base_url
    output_rate = audio.sample_rate_hz
    source_rate = config.source_sample_rate_hz
    tts_settings = OpenAITTSService.Settings(
        model=config.model,
        voice=config.voice,
    )

    # NOTE(review): source_sample_rate is forwarded to both classes; confirm
    # that the stock OpenAITTSService tolerates this extra keyword.
    return service_class(
        api_key=api_key,
        base_url=base_url,
        sample_rate=output_rate,
        source_sample_rate=source_rate,
        settings=tts_settings,
    )
class OpenAICompatibleTTSService(OpenAITTSService):
    """OpenAI-compatible TTS service that permits provider-specific voice ids."""

    def __init__(self, *, source_sample_rate: int | None = None, **kwargs):
        """Initialise the service.

        Args:
            source_sample_rate: Sample rate handed to the audio streaming
                helper as the input rate; a falsy value selects
                ``OPENAI_SAMPLE_RATE``.
            **kwargs: Forwarded unchanged to ``OpenAITTSService.__init__``.
        """
        super().__init__(**kwargs)
        if source_sample_rate:
            self._source_sample_rate = source_sample_rate
        else:
            self._source_sample_rate = OPENAI_SAMPLE_RATE

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Synthesise ``text`` and yield the resulting frames.

        The configured voice is sent to the speech endpoint as-is. An
        ``ErrorFrame`` is yielded instead of raising when the voice is
        missing, the response status is not 200, or any ``Exception`` occurs
        during the request or while streaming.

        Args:
            text: Text to synthesise.
            context_id: Passed through to the audio frame streaming helper.

        Yields:
            Frames from ``_stream_audio_frames_from_iterator``, or a single
            ``ErrorFrame`` describing the failure.
        """
        voice_id = self._settings.voice
        if not voice_id:
            yield ErrorFrame(error="TTS voice must be specified")
            return

        try:
            request = {
                "input": text,
                "model": self._settings.model,
                "voice": voice_id,
                "response_format": "pcm",
            }
            # Optional parameters are only sent when they have a truthy value.
            for option in ("instructions", "speed"):
                option_value = getattr(self._settings, option)
                if option_value:
                    request[option] = option_value

            speech_api = self._client.audio.speech.with_streaming_response
            async with speech_api.create(**request) as response:
                if response.status_code != 200:
                    error = await response.text()
                    yield ErrorFrame(
                        error=f"TTS request failed (status: {response.status_code}, error: {error})"
                    )
                    return

                await self.start_tts_usage_metrics(text)

                async def non_empty_chunks():
                    # Drop empty chunks before they reach the frame helper.
                    async for piece in response.iter_bytes(self.chunk_size):
                        if piece:
                            yield piece

                ttfb_pending = True
                frames = self._stream_audio_frames_from_iterator(
                    non_empty_chunks(),
                    in_sample_rate=self._source_sample_rate,
                    context_id=context_id,
                )
                async for frame in frames:
                    if ttfb_pending:
                        # Stop the time-to-first-byte metric once, on the first frame.
                        await self.stop_ttfb_metrics()
                        ttfb_pending = False
                    yield frame
        except Exception as exc:
            # BadRequestError is an Exception subclass and was reported with
            # this same message, so one handler covers both cases.
            yield ErrorFrame(error=f"TTS request failed: {exc}")
def _require_provider(actual: str, expected: str, service_name: str) -> None:
if actual != expected:
raise ValueError(f"Unsupported {service_name} provider {actual!r}; expected {expected!r}")
def _language(value: str | None) -> Language | str | None:
    """Map a language code string onto a ``Language`` attribute when possible.

    Args:
        value: Language code such as ``"en-US"``, or ``None``.

    Returns:
        ``None`` when ``value`` is ``None``; otherwise the attribute of
        ``Language`` whose name matches the normalised code, or ``value``
        itself, unchanged, when ``Language`` has no such attribute. The
        return annotation includes ``str`` to reflect that pass-through.
    """
    if value is None:
        return None
    # Turn e.g. "en-us" into "EN_US" before looking the name up on Language.
    member_name = value.replace("-", "_").upper()
    return getattr(Language, member_name, value)