258 lines
9.1 KiB
Python
258 lines
9.1 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import os
|
|
import re
|
|
import unicodedata
|
|
from collections.abc import AsyncGenerator, AsyncIterator
|
|
from datetime import datetime, timezone
|
|
from email.utils import format_datetime
|
|
from typing import Any
|
|
from urllib.parse import urlencode, urlparse
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.frames.frames import ErrorFrame, Frame
|
|
from pipecat.services.settings import TTSSettings
|
|
from pipecat.services.tts_service import TTSService
|
|
from websockets.asyncio.client import connect
|
|
|
|
|
|
DEFAULT_XFYUN_TTS_URL = "wss://tts-api.xfyun.cn/v2/tts"


# Strip characters Xfyun's online TTS cannot synthesize. The engine silently
# rejects (or returns empty audio for) text containing emoji and related
# pictographic symbols, which surfaces as "request finished without audio
# data". The class covers supplementary-plane emoji, the BMP symbol/dingbat
# block, and the invisible pieces that glue emoji sequences together
# (variation selectors, zero-width joiner, and the combining keycap mark that
# ends sequences like "1" + U+FE0F + U+20E3).
_EMOJI_AND_SYMBOL_RE = re.compile(
    "["
    "\U0001F300-\U0001FAFF"  # misc pictographs, emoji, symbols, transport, etc.
    "\U00002600-\U000027BF"  # misc symbols and dingbats
    "\U0001F1E6-\U0001F1FF"  # regional indicators (flags)
    "\uFE00-\uFE0F"  # variation selectors
    "\u200D"  # zero-width joiner
    "\u20E3"  # combining enclosing keycap (category Me, not caught elsewhere)
    "]",
    flags=re.UNICODE,
)
|
|
|
|
|
|
class XfyunTTSService(TTSService):
    """iFlytek/Xfyun online TTS service for Pipecat.

    Xfyun's API is not OpenAI-compatible. It uses a signed WebSocket URL,
    receives one JSON request per synthesis, and streams text WebSocket
    messages containing base64-encoded audio chunks. This service requests
    raw PCM so the chunks can become Pipecat audio frames without MP3 decode.
    """

    # Python codecs for Xfyun ``tte`` (text encoding) values. The request text
    # must actually be encoded the way ``tte`` declares it; Xfyun expects
    # ``UNICODE`` text as UTF-16 little-endian. Unrecognized values fall back
    # to UTF-8.
    _TTE_CODECS = {
        "UTF8": "utf-8",
        "GB2312": "gb2312",
        "GBK": "gbk",
        "GB18030": "gb18030",
        "BIG5": "big5",
        "UNICODE": "utf-16-le",
    }

    def __init__(
        self,
        *,
        app_id: str,
        api_key: str,
        api_secret: str,
        voice: str,
        url: str | None = None,
        sample_rate: int = 16000,
        source_sample_rate: int = 16000,
        encoding: str = "raw",
        text_encoding: str = "UTF8",
        speed: int = 50,
        volume: int = 50,
        pitch: int = 50,
        timeout: float = 30.0,
        **kwargs,
    ) -> None:
        """Configure the service.

        Args:
            app_id: Xfyun application id; ``XFYUN_APP_ID`` is used when empty.
            api_key: Xfyun API key; ``XFYUN_API_KEY`` is used when empty.
            api_secret: Secret used to sign the URL; ``XFYUN_API_SECRET`` is
                used when empty.
            voice: Voice name, sent as ``business.vcn``.
            url: WebSocket endpoint; defaults to ``DEFAULT_XFYUN_TTS_URL``.
            sample_rate: Output sample rate handed to ``TTSService``.
            source_sample_rate: Rate requested from Xfyun (``auf``) and used
                as the input rate of the received PCM.
            encoding: Xfyun ``aue`` value; only ``"raw"`` is accepted by
                ``run_tts``.
            text_encoding: Xfyun ``tte`` value; the text is encoded to match.
            speed: Sent as ``business.speed``.
            volume: Sent as ``business.volume``.
            pitch: Sent as ``business.pitch``.
            timeout: Seconds allowed for opening the WebSocket and for each
                wait on the next server message.
            **kwargs: Forwarded to ``TTSService``.
        """
        super().__init__(
            sample_rate=sample_rate,
            settings=TTSSettings(model=None, voice=voice, language=None),
            **kwargs,
        )
        self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
        self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
        self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
        self._voice = voice
        self._url = url or DEFAULT_XFYUN_TTS_URL
        self._source_sample_rate = source_sample_rate
        self._encoding = encoding
        self._text_encoding = text_encoding
        self._speed = speed
        self._volume = volume
        self._pitch = pitch
        self._timeout = timeout
        # Diagnostic string set by _iter_audio_chunks when a session yields no
        # audio; surfaced in the ErrorFrame emitted by run_tts.
        self._last_failure_detail: str | None = None

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Synthesize *text* and yield Pipecat frames.

        Configuration/validation problems and request failures are reported
        as ``ErrorFrame``s instead of being raised. Empty input, or input that
        is empty after sanitization, yields nothing.
        """
        if not text:
            return

        if not self._app_id or not self._api_key or not self._api_secret:
            yield ErrorFrame(error="Xfyun TTS requires app_id, api_key, and api_secret")
            return

        sanitized = _sanitize_text_for_tts(text)
        if not sanitized:
            logger.debug(
                f"{self}: skipping Xfyun TTS, text became empty after sanitization "
                f"(original={text!r})"
            )
            return

        if sanitized != text:
            logger.debug(
                f"{self}: sanitized Xfyun TTS text "
                f"(original={text!r}, sanitized={sanitized!r})"
            )

        # The size limit applies to the bytes actually sent, i.e. the text in
        # the declared ``tte`` encoding (before base64).
        try:
            encoded_length = len(self._encode_text(sanitized))
        except UnicodeEncodeError as exc:
            yield ErrorFrame(
                error=f"Xfyun TTS text cannot be encoded as {self._text_encoding}: {exc}"
            )
            return
        if encoded_length >= 8000:
            yield ErrorFrame(error="Xfyun TTS text must be less than 8000 UTF-8 bytes")
            return

        if self._encoding != "raw":
            yield ErrorFrame(error="Xfyun TTS is configured for PCM output; set aue/encoding to raw")
            return

        try:
            await self.start_tts_usage_metrics(sanitized)

            first_frame = True
            async for frame in self._stream_audio_frames_from_iterator(
                self._iter_audio_chunks(sanitized),
                in_sample_rate=self._source_sample_rate,
                context_id=context_id,
            ):
                if first_frame:
                    await self.stop_ttfb_metrics()
                    first_frame = False
                yield frame

            if first_frame:
                detail = self._last_failure_detail or "no audio frames received"
                yield ErrorFrame(
                    error=(
                        f"Xfyun TTS request finished without audio data ({detail}); "
                        f"text={sanitized!r}"
                    )
                )
        except Exception as exc:
            yield ErrorFrame(error=f"Xfyun TTS request failed: {exc}")

    async def _iter_audio_chunks(self, text: str) -> AsyncIterator[bytes]:
        """Run one Xfyun session for *text* and yield decoded PCM chunks.

        Sends a single request frame, then reads JSON messages until one has
        ``data.status == 2`` or the server closes the connection normally.

        Raises:
            RuntimeError: the server returned a non-zero ``code``, or no
                message arrived within ``self._timeout`` seconds.
        """
        import asyncio

        from websockets.exceptions import ConnectionClosedOK

        request = self._build_request_frame(text)
        auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)

        self._last_failure_detail = None
        frames_received = 0
        audio_bytes_received = 0
        last_status: int | None = None
        last_sid: str | None = None
        saw_status_2 = False

        async with connect(auth_url, max_size=None, open_timeout=self._timeout) as websocket:
            await websocket.send(json.dumps(request, ensure_ascii=False))

            while True:
                # Bound every wait so a stalled server cannot hang the
                # pipeline; cancelling recv() is safe in websockets' client.
                try:
                    raw_message = await asyncio.wait_for(
                        websocket.recv(), timeout=self._timeout
                    )
                except ConnectionClosedOK:
                    # A normal close ends the stream, as iterating the
                    # connection would.
                    break
                except asyncio.TimeoutError as exc:
                    raise RuntimeError(
                        f"no message from Xfyun within {self._timeout}s "
                        f"(frames={frames_received}, sid={last_sid})"
                    ) from exc

                frames_received += 1
                payload = json.loads(raw_message)
                code = payload.get("code", -1)
                sid = payload.get("sid")
                if sid:
                    last_sid = sid
                if code != 0:
                    err_msg = payload.get("message", "unknown error")
                    raise RuntimeError(f"code={code}, sid={sid}, message={err_msg}")

                data = payload.get("data")
                if not isinstance(data, dict):
                    continue

                last_status = data.get("status", last_status)

                audio_b64 = data.get("audio")
                if audio_b64:
                    audio_bytes = base64.b64decode(audio_b64)
                    audio_bytes_received += len(audio_bytes)
                    yield audio_bytes

                if data.get("status") == 2:
                    saw_status_2 = True
                    break

        if audio_bytes_received == 0:
            self._last_failure_detail = (
                f"frames={frames_received}, audio_bytes=0, "
                f"last_status={last_status}, saw_status_2={saw_status_2}, sid={last_sid}"
            )
            logger.warning(
                f"{self}: Xfyun TTS produced no audio ({self._last_failure_detail})"
            )
        elif not saw_status_2:
            # Audio arrived but the final frame never did: output may be cut off.
            logger.warning(
                f"{self}: Xfyun TTS stream closed before the final frame, audio may be "
                f"truncated (frames={frames_received}, audio_bytes={audio_bytes_received}, "
                f"last_status={last_status}, sid={last_sid})"
            )

    def _encode_text(self, text: str) -> bytes:
        """Encode *text* with the codec matching ``self._text_encoding`` (``tte``)."""
        codec = self._TTE_CODECS.get(self._text_encoding.upper(), "utf-8")
        return text.encode(codec)

    def _build_request_frame(self, text: str) -> dict[str, Any]:
        """Build the single JSON request sent for one synthesis.

        ``data.status`` is fixed at 2: the whole text goes in this one frame.
        """
        business: dict[str, Any] = {
            "aue": self._encoding,
            "auf": f"audio/L16;rate={self._source_sample_rate}",
            "vcn": self._voice,
            "speed": self._speed,
            "volume": self._volume,
            "pitch": self._pitch,
            "tte": self._text_encoding,
        }

        return {
            "common": {"app_id": self._app_id},
            "business": business,
            "data": {
                "status": 2,
                # Encode with the codec declared in ``tte`` so the server
                # decodes the same characters that were intended.
                "text": base64.b64encode(self._encode_text(text)).decode("utf-8"),
            },
        }
|
|
|
|
|
|
def _sanitize_text_for_tts(text: str) -> str:
    """Return *text* with characters Xfyun's online TTS cannot speak removed.

    Xfyun's ``/v2/tts`` engine chokes on emoji, pictographs, dingbats,
    regional-indicator flags, variation selectors and zero-width joiners;
    when such characters are present the synthesis may end with no audio
    ("Xfyun TTS request finished without audio data").

    Cleanup happens in three passes:
      1. delete everything matched by ``_EMOJI_AND_SYMBOL_RE``;
      2. delete "Symbol, Other" (``So``) code points and every control/format
         (``C*``) code point except newline, carriage return and tab;
      3. squeeze each whitespace run to one space and trim the ends.

    Falsy input is returned unchanged.
    """
    if not text:
        return text

    allowed_controls = ("\n", "\r", "\t")

    def _is_speakable(char: str) -> bool:
        # Symbols are always dropped; control-ish categories survive only
        # for the three common whitespace characters.
        cat = unicodedata.category(char)
        if cat == "So":
            return False
        return not cat.startswith("C") or char in allowed_controls

    without_emoji = _EMOJI_AND_SYMBOL_RE.sub("", text)
    survivors = "".join(char for char in without_emoji if _is_speakable(char))
    return re.sub(r"\s+", " ", survivors).strip()
|
|
|
|
|
|
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return *url* with Xfyun's HMAC-SHA256 authentication query attached.

    The string ``host: <host>\\ndate: <date>\\nGET <path> HTTP/1.1`` is signed
    with *api_secret*; the signature and *api_key* are packed into a
    base64 ``authorization`` value and sent as query parameters alongside
    ``date`` and ``host``.

    The returned URL always requests exactly the path that was signed (when
    *url* has no path, ``/v2/tts`` is used for both). A query string already
    present in *url* is preserved and the auth parameters are appended to it.

    Args:
        url: WebSocket endpoint, e.g. ``wss://tts-api.xfyun.cn/v2/tts``.
        api_key: Xfyun API key embedded in the authorization value.
        api_secret: Xfyun API secret used as the HMAC key.

    Returns:
        The signed URL. The signature embeds the current UTC time, so build
        it immediately before connecting.
    """
    parsed = urlparse(url)
    host = parsed.netloc
    path = parsed.path or "/v2/tts"
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)

    request_line = f"GET {path} HTTP/1.1"
    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")

    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})
    # Merge with any caller-supplied query instead of emitting "...?a=b?authorization=...".
    query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    # Request the same path that was signed, or the handshake signature won't match.
    return parsed._replace(path=path, query=query).geturl()
|