127 lines
4.5 KiB
Python
127 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
|
|
from openai import BadRequestError
|
|
from openai import NOT_GIVEN
|
|
|
|
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
|
|
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
|
from pipecat.services.openai.llm import OpenAILLMService
|
|
from pipecat.services.openai.stt import OpenAISTTService
|
|
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
|
|
from pipecat.transcriptions.language import Language
|
|
|
|
from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig
|
|
|
|
|
|
def create_stt_service(config: STTConfig):
    """Build an ``OpenAISTTService`` from the speech-to-text configuration.

    Args:
        config: STT settings; ``provider``, ``api_key``, ``base_url``,
            ``model`` and ``language`` are read.

    Returns:
        A configured ``OpenAISTTService`` instance.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "stt")

    stt_settings = OpenAISTTService.Settings(
        model=config.model,
        language=_language(config.language),
    )
    # A falsy api_key (e.g. an empty string) is passed on as None.
    key = config.api_key or None
    return OpenAISTTService(api_key=key, base_url=config.base_url, settings=stt_settings)
|
|
|
|
|
|
def create_llm_service(config: LLMConfig):
    """Build an ``OpenAILLMService`` from the language-model configuration.

    Args:
        config: LLM settings; ``provider``, ``api_key``, ``base_url``,
            ``model`` and ``temperature`` are read.

    Returns:
        A configured ``OpenAILLMService`` instance.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "llm")

    # An unset temperature is expressed with the OpenAI NOT_GIVEN sentinel
    # rather than None.
    if config.temperature is None:
        temperature = NOT_GIVEN
    else:
        temperature = config.temperature

    llm_settings = OpenAILLMService.Settings(model=config.model, temperature=temperature)
    # A falsy api_key (e.g. an empty string) is passed on as None.
    key = config.api_key or None
    return OpenAILLMService(api_key=key, base_url=config.base_url, settings=llm_settings)
|
|
|
|
|
|
def create_tts_service(config: TTSConfig, audio: AudioConfig):
    """Build a text-to-speech service from the TTS and audio configuration.

    ``OpenAITTSService`` is used when ``config.voice`` is one of
    ``VALID_VOICES``; any other voice selects ``OpenAICompatibleTTSService``.

    Args:
        config: TTS settings; ``provider``, ``api_key``, ``base_url``,
            ``model``, ``voice`` and ``source_sample_rate_hz`` are read.
        audio: Audio settings; ``sample_rate_hz`` is passed as ``sample_rate``.

    Returns:
        An instance of the selected TTS service class.

    Raises:
        ValueError: If ``config.provider`` is not ``"openai"``.
    """
    _require_provider(config.provider, "openai", "tts")

    if config.voice in VALID_VOICES:
        tts_class = OpenAITTSService
    else:
        tts_class = OpenAICompatibleTTSService

    # The Settings type of the base class is used for both service classes.
    tts_settings = OpenAITTSService.Settings(model=config.model, voice=config.voice)

    # NOTE(review): source_sample_rate is forwarded to whichever class was
    # picked, including plain OpenAITTSService - confirm that constructor
    # accepts (or deliberately ignores) this keyword.
    return tts_class(
        api_key=config.api_key or None,
        base_url=config.base_url,
        sample_rate=audio.sample_rate_hz,
        source_sample_rate=config.source_sample_rate_hz,
        settings=tts_settings,
    )
|
|
|
|
|
|
class OpenAICompatibleTTSService(OpenAITTSService):
    """TTS service for OpenAI-compatible endpoints with provider-specific voices.

    The configured voice string is sent to the speech endpoint unchanged, and
    the sample rate of the audio coming back can be overridden.
    """

    def __init__(self, *, source_sample_rate: int | None = None, **kwargs):
        """Initialise the service.

        Args:
            source_sample_rate: Sample rate used as ``in_sample_rate`` when
                streaming the returned audio. A falsy value (``None`` or
                ``0``) selects ``OPENAI_SAMPLE_RATE``.
            **kwargs: Passed through to ``OpenAITTSService.__init__``.
        """
        super().__init__(**kwargs)
        self._source_sample_rate = (
            source_sample_rate if source_sample_rate else OPENAI_SAMPLE_RATE
        )

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Synthesise ``text`` and yield the resulting frames.

        Args:
            text: Text sent as the ``input`` of the speech request.
            context_id: Forwarded to ``_stream_audio_frames_from_iterator``.

        Yields:
            Frames produced by ``_stream_audio_frames_from_iterator``, or a
            single ``ErrorFrame`` when no voice is configured, the response
            status is not 200, or an exception is raised.
        """
        voice_id = self._settings.voice
        if not voice_id:
            yield ErrorFrame(error="TTS voice must be specified")
            return

        try:
            request_kwargs = dict(
                input=text,
                model=self._settings.model,
                voice=voice_id,
                response_format="pcm",
            )

            # Optional fields are only sent when they hold a truthy value.
            instructions = self._settings.instructions
            if instructions:
                request_kwargs["instructions"] = instructions

            speed = self._settings.speed
            if speed:
                request_kwargs["speed"] = speed

            speech_api = self._client.audio.speech.with_streaming_response
            async with speech_api.create(**request_kwargs) as response:
                status = response.status_code
                if status != 200:
                    detail = await response.text()
                    yield ErrorFrame(
                        error=f"TTS request failed (status: {status}, error: {detail})"
                    )
                    return

                await self.start_tts_usage_metrics(text)

                async def non_empty_chunks():
                    # Skip empty byte chunks from the response stream.
                    async for piece in response.iter_bytes(self.chunk_size):
                        if piece:
                            yield piece

                frames = self._stream_audio_frames_from_iterator(
                    non_empty_chunks(),
                    in_sample_rate=self._source_sample_rate,
                    context_id=context_id,
                )

                # TTFB metrics are stopped once, just before the first frame
                # is handed to the caller.
                ttfb_stopped = False
                async for frame in frames:
                    if not ttfb_stopped:
                        await self.stop_ttfb_metrics()
                        ttfb_stopped = True
                    yield frame
        except Exception as exc:
            # BadRequestError gets the same message as every other exception,
            # so one handler covers both cases.
            yield ErrorFrame(error=f"TTS request failed: {exc}")
|
|
|
|
|
|
def _require_provider(actual: str, expected: str, service_name: str) -> None:
|
|
if actual != expected:
|
|
raise ValueError(f"Unsupported {service_name} provider {actual!r}; expected {expected!r}")
|
|
|
|
|
|
def _language(value: str | None) -> Language | None:
|
|
if value is None:
|
|
return None
|
|
normalized = value.replace("-", "_").upper()
|
|
return getattr(Language, normalized, value)
|