221 lines
8.2 KiB
Python
221 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
|
|
from openai import BadRequestError
|
|
from openai import NOT_GIVEN
|
|
|
|
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
|
|
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
|
from pipecat.services.openai.llm import OpenAILLMService
|
|
from pipecat.services.openai.stt import OpenAISTTService
|
|
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
|
|
from pipecat.services.tts_service import TextAggregationMode
|
|
from pipecat.transcriptions.language import Language
|
|
|
|
from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig
|
|
from .fastgpt_llm import FastGPTLLMService, FastGPTLLMSettings
|
|
from .xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService
|
|
from .xfyun_super_tts import DEFAULT_XFYUN_SUPER_TTS_URL, XfyunSuperTTSService
|
|
from .xfyun_tts import DEFAULT_XFYUN_TTS_URL, XfyunTTSService
|
|
|
|
|
|
def create_stt_service(config: STTConfig, audio: AudioConfig | None = None):
    """Build the speech-to-text service selected by ``config.provider``.

    ``"xfyun"`` produces an ``XfyunASRService``; every other provider must be
    ``"openai"`` and produces an ``OpenAISTTService``.

    Args:
        config: STT settings (provider, credentials, endpoint and
            provider-specific options).
        audio: Optional audio settings. Only consulted on the xfyun path,
            where ``audio.sample_rate_hz`` is used; 16000 is used when
            ``audio`` is falsy.

    Returns:
        The constructed STT service instance.

    Raises:
        ValueError: If the provider is neither ``"xfyun"`` nor ``"openai"``.
    """
    if config.provider != "xfyun":
        # Anything that is not xfyun has to be the OpenAI provider.
        _require_provider(config.provider, "openai", "stt")
        return OpenAISTTService(
            api_key=config.api_key or None,
            base_url=config.base_url,
            settings=OpenAISTTService.Settings(
                model=config.model,
                language=_language(config.language),
            ),
        )

    if audio:
        sample_rate = audio.sample_rate_hz
    else:
        sample_rate = 16000

    xfyun_options = {
        "app_id": config.app_id,
        "api_key": config.api_key or "",
        "api_secret": config.api_secret,
        "url": config.base_url or DEFAULT_XFYUN_ASR_URL,
        "language": config.language or "zh_cn",
        "domain": config.domain,
        "accent": config.accent,
        "sample_rate": sample_rate,
        "encoding": config.encoding,
        "frame_size": config.frame_size,
        "open_timeout": config.timeout_sec,
        "dynamic_correction": config.dynamic_correction,
    }
    return XfyunASRService(**xfyun_options)
|
|
|
|
|
|
def create_llm_service(
    config: LLMConfig,
    *,
    chat_id: str | None = None,
    session_variables: dict | None = None,
    greeting_prompt: str | None = None,
):
    """Build the LLM service described by ``config``.

    When ``config.is_fastgpt`` is true a ``FastGPTLLMService`` is returned;
    otherwise ``config.is_openai`` must be true and an ``OpenAILLMService``
    is returned.

    Args:
        config: LLM settings (provider flags, credentials, endpoint, model).
        chat_id: Forwarded to the FastGPT service only.
        session_variables: Per-session variables layered on top of
            ``config.variables`` (session values win on key collisions).
            FastGPT only.
        greeting_prompt: Forwarded to the FastGPT service only.

    Returns:
        The constructed LLM service instance.

    Raises:
        ValueError: If the config is neither FastGPT nor OpenAI.
    """
    if config.is_fastgpt:
        # Start from the configured variables, then let the session override.
        merged_variables = {**config.variables}
        merged_variables.update(session_variables or {})

        fastgpt_settings = FastGPTLLMSettings(
            model=config.model or "fastgpt",
            variables=merged_variables,
            detail=config.detail,
        )
        return FastGPTLLMService(
            api_key=config.api_key,
            base_url=config.base_url or "http://localhost:3000",
            chat_id=chat_id,
            app_id=config.app_id,
            greeting_prompt=greeting_prompt,
            timeout=config.timeout_sec,
            image_input_mode=config.image_input_mode,
            settings=fastgpt_settings,
        )

    if config.is_openai:
        # An unset temperature is passed as the OpenAI "not given" sentinel.
        temperature = config.temperature
        if temperature is None:
            temperature = NOT_GIVEN

        return OpenAILLMService(
            api_key=config.api_key or None,
            base_url=config.base_url,
            settings=OpenAILLMService.Settings(
                model=config.model,
                temperature=temperature,
            ),
        )

    supported = ", ".join(sorted(("openai", "fastgpt", "llm")))
    raise ValueError(
        f"Unsupported llm provider {config.provider!r}; expected one of: {supported}"
    )
|
|
|
|
|
|
def create_tts_service(config: TTSConfig, audio: AudioConfig):
    """Build the text-to-speech service selected by ``config.provider``.

    * ``"xfyun"`` -> ``XfyunTTSService``
    * ``"xfyun_super"`` / ``"xfyun_super_tts"`` -> ``XfyunSuperTTSService``
    * ``"openai"`` -> ``OpenAITTSService`` when ``config.voice`` is one of
      ``VALID_VOICES``, otherwise ``OpenAICompatibleTTSService``

    Args:
        config: TTS settings (provider, credentials, voice, endpoint, etc.).
        audio: Audio settings; ``audio.sample_rate_hz`` is passed to the
            service as ``sample_rate``.

    Returns:
        The constructed TTS service instance.

    Raises:
        ValueError: If the source sample rate is not allowed for the chosen
            xfyun provider, or if the provider is not recognised.
    """
    provider = config.provider

    if provider == "xfyun":
        # Falls back to the pipeline rate when no source rate is configured.
        source_rate = config.source_sample_rate_hz or audio.sample_rate_hz
        if source_rate in (8000, 16000):
            return XfyunTTSService(
                app_id=config.app_id,
                api_key=config.api_key or "",
                api_secret=config.api_secret,
                voice=config.voice,
                url=config.base_url or DEFAULT_XFYUN_TTS_URL,
                sample_rate=audio.sample_rate_hz,
                source_sample_rate=source_rate,
                encoding=config.aue,
                text_encoding=config.tte,
                speed=config.speed,
                volume=config.volume,
                pitch=config.pitch,
                timeout=config.timeout_sec,
                push_stop_frames=True,
            )
        raise ValueError("Xfyun TTS source_sample_rate_hz must be 8000 or 16000")

    if provider in ("xfyun_super", "xfyun_super_tts"):
        # Unlike plain xfyun, the default here is a fixed 24000.
        source_rate = config.source_sample_rate_hz or 24000
        if source_rate in (8000, 16000, 24000):
            aggregation_mode = config.text_aggregation_mode or TextAggregationMode.TOKEN
            return XfyunSuperTTSService(
                app_id=config.app_id,
                api_key=config.api_key or "",
                api_secret=config.api_secret,
                voice=config.voice,
                url=config.base_url or DEFAULT_XFYUN_SUPER_TTS_URL,
                sample_rate=audio.sample_rate_hz,
                source_sample_rate=source_rate,
                encoding=config.aue,
                speed=config.speed,
                volume=config.volume,
                pitch=config.pitch,
                oral_level=config.oral_level,
                text_aggregation_mode=aggregation_mode,
                open_timeout=config.timeout_sec,
            )
        raise ValueError(
            "Xfyun Super TTS source_sample_rate_hz must be 8000, 16000, or 24000"
        )

    _require_provider(provider, "openai", "tts")

    # Voices outside VALID_VOICES are routed to the compatible subclass.
    if config.voice in VALID_VOICES:
        service_class = OpenAITTSService
    else:
        service_class = OpenAICompatibleTTSService

    # NOTE(review): source_sample_rate is forwarded to whichever class was
    # picked; only OpenAICompatibleTTSService visibly declares that keyword in
    # this module -- confirm OpenAITTSService accepts it as well.
    openai_settings = OpenAITTSService.Settings(
        model=config.model,
        voice=config.voice,
    )
    return service_class(
        api_key=config.api_key or None,
        base_url=config.base_url,
        sample_rate=audio.sample_rate_hz,
        source_sample_rate=config.source_sample_rate_hz,
        settings=openai_settings,
    )
|
|
|
|
|
|
class OpenAICompatibleTTSService(OpenAITTSService):
    """OpenAI-compatible TTS service that permits provider-specific voice ids.

    The configured voice is sent to the speech endpoint as-is, and the
    sample rate of the returned PCM stream can be overridden.
    """

    def __init__(self, *, source_sample_rate: int | None = None, **kwargs):
        """Initialise the service.

        Args:
            source_sample_rate: Sample rate of the audio returned by the
                endpoint. ``OPENAI_SAMPLE_RATE`` is used when this is falsy.
            **kwargs: Forwarded unchanged to ``OpenAITTSService``.
        """
        super().__init__(**kwargs)
        if source_sample_rate:
            self._source_sample_rate = source_sample_rate
        else:
            self._source_sample_rate = OPENAI_SAMPLE_RATE

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Request speech for ``text`` and yield the resulting frames.

        Args:
            text: Text to synthesise.
            context_id: Identifier passed through to the frame streamer.

        Yields:
            Frames from ``_stream_audio_frames_from_iterator``, or a single
            ``ErrorFrame`` when no voice is configured, the endpoint answers
            with a non-200 status, or an exception is raised.
        """
        voice_id = self._settings.voice
        if not voice_id:
            yield ErrorFrame(error="TTS voice must be specified")
            return

        try:
            request = {
                "input": text,
                "model": self._settings.model,
                "voice": voice_id,
                "response_format": "pcm",
            }

            # Optional parameters are only sent when they hold a truthy value.
            for option in ("instructions", "speed"):
                option_value = getattr(self._settings, option)
                if option_value:
                    request[option] = option_value

            speech = self._client.audio.speech.with_streaming_response
            async with speech.create(**request) as resp:
                if resp.status_code != 200:
                    error = await resp.text()
                    yield ErrorFrame(
                        error=f"TTS request failed (status: {resp.status_code}, error: {error})"
                    )
                    return

                await self.start_tts_usage_metrics(text)

                async def non_empty_chunks():
                    # Drop empty byte chunks before they reach the streamer.
                    async for piece in resp.iter_bytes(self.chunk_size):
                        if piece:
                            yield piece

                ttfb_stopped = False
                async for audio_frame in self._stream_audio_frames_from_iterator(
                    non_empty_chunks(),
                    in_sample_rate=self._source_sample_rate,
                    context_id=context_id,
                ):
                    # TTFB metrics are stopped once, on the first frame.
                    if not ttfb_stopped:
                        await self.stop_ttfb_metrics()
                        ttfb_stopped = True
                    yield audio_frame
        except BadRequestError as exc:
            yield ErrorFrame(error=f"TTS request failed: {exc}")
        except Exception as exc:
            yield ErrorFrame(error=f"TTS request failed: {exc}")
|
|
|
|
|
|
def _require_provider(actual: str, expected: str, service_name: str) -> None:
|
|
if actual != expected:
|
|
raise ValueError(f"Unsupported {service_name} provider {actual!r}; expected {expected!r}")
|
|
|
|
|
|
def _language(value: str | None) -> Language | None:
|
|
if value is None:
|
|
return None
|
|
normalized = value.replace("-", "_").upper()
|
|
return getattr(Language, normalized, value)
|