392 lines
15 KiB
Python
392 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import os
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import datetime, timezone
|
|
from email.utils import format_datetime
|
|
from typing import Any
|
|
from urllib.parse import urlencode, urlparse
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
EndFrame,
|
|
ErrorFrame,
|
|
Frame,
|
|
StartFrame,
|
|
TTSAudioRawFrame,
|
|
TTSStoppedFrame,
|
|
)
|
|
from pipecat.services.settings import TTSSettings
|
|
from pipecat.services.tts_service import TextAggregationMode, WebsocketTTSService
|
|
from pipecat.utils.tracing.service_decorators import traced_tts
|
|
|
|
try:
|
|
from websockets.asyncio.client import connect as websocket_connect
|
|
from websockets.protocol import State
|
|
except ModuleNotFoundError as exc:
|
|
logger.error(f"Exception: {exc}")
|
|
logger.error("In order to use Xfyun Super TTS, install the websockets package.")
|
|
raise Exception(f"Missing module: {exc}") from exc
|
|
|
|
from .xfyun_tts import _sanitize_text_for_tts
|
|
|
|
|
|
# Default Xfyun Super Smart TTS WebSocket endpoint; used when no ``url`` is
# passed to ``XfyunSuperTTSService``.
DEFAULT_XFYUN_SUPER_TTS_URL = "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6"

# Source sample rates (Hz) that ``XfyunSuperTTSService.start`` accepts for
# ``source_sample_rate``.
VALID_SAMPLE_RATES = {8_000, 16_000, 24_000}
|
|
|
|
|
|
class XfyunSuperTTSService(WebsocketTTSService):
    """iFlytek/Xfyun Super Smart TTS using bidirectional WebSocket streaming.

    The service keeps one Xfyun synthesis session open for a Pipecat turn. Each
    ``run_tts`` call sends a text segment with status 0 (first segment of a
    context) or 1 (continuation), while ``flush_audio`` sends the terminal
    status 2 frame. Audio arrives on the receive task, is resampled from
    ``source_sample_rate`` to the service sample rate when the two differ, and
    is appended to the Pipecat audio context.

    Only one synthesis stream may be active per WebSocket: text for a different
    context while one is active is rejected with an ``ErrorFrame``.
    """

    # Upper bound on UTF-8 text bytes sent within a single synthesis stream.
    _MAX_TEXT_BYTES_PER_STREAM = 65536

    def __init__(
        self,
        *,
        app_id: str | None = None,
        api_key: str | None = None,
        api_secret: str | None = None,
        voice: str,
        url: str | None = None,
        sample_rate: int = 16000,
        source_sample_rate: int = 24000,
        encoding: str = "raw",
        speed: int = 50,
        volume: int = 50,
        pitch: int = 50,
        oral_level: str = "mid",
        text_aggregation_mode: TextAggregationMode | str | None = TextAggregationMode.TOKEN,
        open_timeout: float = 30.0,
        **kwargs,
    ) -> None:
        """Create the service.

        Args:
            app_id: Xfyun application id. Falls back to ``XFYUN_APP_ID`` when
                omitted or empty.
            api_key: Xfyun API key. Falls back to ``XFYUN_API_KEY`` when
                omitted or empty.
            api_secret: Xfyun API secret used to sign the connection URL.
                Falls back to ``XFYUN_API_SECRET`` when omitted or empty.
            voice: Voice name, sent to Xfyun as ``vcn``.
            url: WebSocket endpoint; defaults to ``DEFAULT_XFYUN_SUPER_TTS_URL``.
            sample_rate: Sample rate passed to the Pipecat base class and used
                for the emitted ``TTSAudioRawFrame``s.
            source_sample_rate: Sample rate requested from Xfyun; ``start``
                rejects values outside ``VALID_SAMPLE_RATES``.
            encoding: Audio encoding requested from Xfyun; ``start`` only
                accepts ``"raw"``.
            speed: Sent verbatim as ``tts.speed``.
            volume: Sent verbatim as ``tts.volume``.
            pitch: Sent verbatim as ``tts.pitch``.
            oral_level: Sent verbatim as ``oral.oral_level``.
            text_aggregation_mode: A ``TextAggregationMode`` or its string value.
            open_timeout: Seconds passed as ``open_timeout`` to the websockets
                client ``connect``.
            **kwargs: Forwarded to ``WebsocketTTSService``.
        """
        if isinstance(text_aggregation_mode, str):
            text_aggregation_mode = TextAggregationMode(text_aggregation_mode)

        super().__init__(
            text_aggregation_mode=text_aggregation_mode,
            push_text_frames=True,
            push_stop_frames=False,
            push_start_frame=True,
            pause_frame_processing=False,
            sample_rate=sample_rate,
            settings=TTSSettings(model=None, voice=voice, language=None),
            **kwargs,
        )

        # Explicit arguments win; omitted/empty credentials fall back to the environment.
        self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
        self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
        self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
        self._voice = voice
        self._url = url or DEFAULT_XFYUN_SUPER_TTS_URL
        self._source_sample_rate = source_sample_rate
        self._encoding = encoding
        self._speed = speed
        self._volume = volume
        self._pitch = pitch
        self._oral_level = oral_level
        self._open_timeout = open_timeout

        self._receive_task: asyncio.Task | None = None
        # Context bound to the open Xfyun stream; every incoming message is routed to it.
        self._active_context_id: str | None = None
        # Contexts whose status-0 frame has been sent.
        self._started_contexts: set[str] = set()
        # Next ``seq`` value to send, per context.
        self._seq_by_context: dict[str, int] = {}
        # Running total of UTF-8 text bytes sent, per context.
        self._sent_text_bytes_by_context: dict[str, int] = {}
        # Set when the server reports status 2; ``run_tts`` then reconnects first.
        self._stream_completed = False

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    async def start(self, frame: StartFrame) -> None:
        """Validate configuration, then open the WebSocket.

        Configuration problems are reported with ``push_error`` and no
        connection is attempted.
        """
        await super().start(frame)
        if not self._app_id or not self._api_key or not self._api_secret:
            await self.push_error(
                error_msg="Xfyun Super TTS requires app_id, api_key, and api_secret"
            )
            return
        if self._encoding != "raw":
            await self.push_error(error_msg="Xfyun Super TTS must use raw PCM audio in Pipecat")
            return
        if self._source_sample_rate not in VALID_SAMPLE_RATES:
            await self.push_error(
                error_msg=(
                    "Xfyun Super TTS source_sample_rate must be one of "
                    f"{sorted(VALID_SAMPLE_RATES)}"
                )
            )
            return
        await self._connect()

    async def stop(self, frame: EndFrame) -> None:
        """Stop the service and close the WebSocket."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame) -> None:
        """Cancel the service and close the WebSocket."""
        await super().cancel(frame)
        await self._disconnect()

    async def flush_audio(self, context_id: str | None = None) -> None:
        """Send the terminal (status 2) frame for a started stream.

        Args:
            context_id: Context to flush; defaults to the active audio context.

        Does nothing when there is no target context, no WebSocket, or the
        context has not sent its first (status 0) segment.
        """
        flush_id = context_id or self.get_active_audio_context_id()
        if not flush_id or not self._websocket:
            return
        if flush_id not in self._started_contexts:
            return

        logger.trace(f"{self}: flushing Xfyun Super TTS stream {flush_id}")
        await self._send_request_frame(flush_id, "", status=2)

    async def on_audio_context_interrupted(self, context_id: str) -> None:
        """Discard the interrupted stream: reset its bookkeeping and reconnect."""
        await self.stop_all_metrics()
        await self._reset_context(context_id)
        await self._disconnect()
        await self._connect()
        await super().on_audio_context_interrupted(context_id)

    async def _connect(self) -> None:
        """Open the WebSocket and start the receive task if it is not running."""
        await super()._connect()
        await self._connect_websocket()
        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self) -> None:
        """Cancel the receive task and close the WebSocket."""
        await super()._disconnect()
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None
        await self._disconnect_websocket()

    async def _connect_websocket(self) -> None:
        """Open an authenticated WebSocket unless one is already open.

        Failures (including an invalid ``url``) are reported through
        ``push_error`` and the ``on_connection_error`` event, and leave
        ``self._websocket`` set to ``None``.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return
            logger.debug("Connecting to Xfyun Super TTS")
            auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
            self._websocket = await websocket_connect(
                auth_url,
                max_size=None,
                open_timeout=self._open_timeout,
            )
            await self._call_event_handler("on_connected")
        except Exception as exc:
            self._websocket = None
            await self.push_error(
                error_msg=f"Unable to connect to Xfyun Super TTS: {exc}",
                exception=exc,
            )
            await self._call_event_handler("on_connection_error", f"{exc}")

    async def _disconnect_websocket(self) -> None:
        """Close the WebSocket and clear all per-stream bookkeeping.

        Close errors are reported via ``push_error``. The state reset and the
        ``on_disconnected`` event happen regardless.
        """
        try:
            await self.stop_all_metrics()
            if self._websocket:
                logger.debug("Disconnecting from Xfyun Super TTS")
                await self._websocket.close()
        except Exception as exc:
            await self.push_error(
                error_msg=f"Error closing Xfyun Super TTS websocket: {exc}",
                exception=exc,
            )
        finally:
            await self.remove_active_audio_context()
            self._websocket = None
            self._active_context_id = None
            self._started_contexts.clear()
            self._seq_by_context.clear()
            self._sent_text_bytes_by_context.clear()
            self._stream_completed = False
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the open WebSocket or raise if there is none."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _receive_messages(self) -> None:
        """Consume server messages and route them to the active context.

        Every message is attributed to ``self._active_context_id``:

        * a non-zero ``header.code`` is reported via ``push_error`` and closes
          the context;
        * base64 ``payload.audio.audio`` data is decoded and appended;
        * a ``status`` of 2 in either ``payload.audio`` or ``header`` closes the
          context and marks the stream completed.
        """
        async for raw_message in self._get_websocket():
            try:
                message = json.loads(raw_message)
            except json.JSONDecodeError:
                logger.warning(f"{self}: received non-JSON Xfyun Super TTS message: {raw_message!r}")
                continue

            header = message.get("header") or {}
            code = header.get("code", -1)
            sid = header.get("sid")
            context_id = self._active_context_id

            if code != 0:
                error_message = header.get("message", "unknown error")
                await self.push_error(
                    error_msg=f"Xfyun Super TTS error code={code}, sid={sid}: {error_message}"
                )
                await self._finish_xfyun_context(context_id)
                continue

            audio_obj = (message.get("payload") or {}).get("audio") or {}
            audio_b64 = audio_obj.get("audio")
            if audio_b64 and context_id and self.audio_context_available(context_id):
                await self.stop_ttfb_metrics()
                await self._append_xfyun_audio(context_id, audio_b64)

            if audio_obj.get("status") == 2 or header.get("status") == 2:
                await self._finish_xfyun_context(context_id)
                self._stream_completed = True

    async def _append_xfyun_audio(self, context_id: str, audio_b64: str) -> None:
        """Decode one base64 audio chunk, resample if needed, and append it to ``context_id``."""
        try:
            audio = base64.b64decode(audio_b64)
        except ValueError as exc:
            # binascii.Error subclasses ValueError. Drop the bad chunk instead of
            # letting the exception end the receive loop.
            logger.warning(f"{self}: dropping undecodable Xfyun Super TTS audio chunk: {exc}")
            return
        if self._source_sample_rate != self.sample_rate:
            audio = await self._resampler.resample(
                audio, self._source_sample_rate, self.sample_rate
            )
        frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
        await self.append_to_audio_context(context_id, frame)

    async def _finish_xfyun_context(self, context_id: str | None) -> None:
        """Stop and remove ``context_id``'s audio context, then drop its stream state.

        The ``TTSStoppedFrame`` is only appended while the audio context is
        still available. ``None`` is ignored.
        """
        if not context_id:
            return
        if self.audio_context_available(context_id):
            await self.append_to_audio_context(
                context_id, TTSStoppedFrame(context_id=context_id)
            )
            await self.remove_audio_context(context_id)
        await self._reset_context(context_id)

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Send one text segment for ``context_id`` on the Xfyun stream.

        Audio is not yielded here; the receive task appends it to the audio
        context. The first segment of a context goes out with status 0 and
        later ones with status 1. If the send fails, an ``ErrorFrame`` and
        ``TTSStoppedFrame`` are yielded and the connection is rebuilt.
        """
        sanitized = _sanitize_text_for_tts(text)
        if not sanitized:
            return

        if not self._is_streaming_tokens:
            logger.debug(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
        else:
            logger.trace(f"{self}: Generating Xfyun Super TTS [{sanitized}]")

        # The previous stream finished (status 2); start again on a fresh connection.
        if self._stream_completed and self._websocket:
            await self._disconnect()
            await self._connect()

        if not self._websocket or self._websocket.state is State.CLOSED:
            await self._connect()

        if self._active_context_id and self._active_context_id != context_id:
            yield ErrorFrame(
                error=(
                    "Xfyun Super TTS supports one active synthesis stream per WebSocket; "
                    f"active={self._active_context_id}, new={context_id}"
                )
            )
            return

        try:
            status = 0 if context_id not in self._started_contexts else 1
            await self._send_request_frame(context_id, sanitized, status=status)
            await self.start_tts_usage_metrics(sanitized)
        except Exception as exc:
            yield ErrorFrame(error=f"Xfyun Super TTS request failed: {exc}")
            yield TTSStoppedFrame(context_id=context_id)
            # Reconnecting also clears all per-stream bookkeeping.
            await self._disconnect()
            await self._connect()
            return

        yield None

    async def _send_request_frame(self, context_id: str, text: str, *, status: int) -> None:
        """Send one request frame for ``context_id`` and advance its sequence state.

        Raises:
            ValueError: if the stream's cumulative UTF-8 text would exceed
                ``_MAX_TEXT_BYTES_PER_STREAM``. No state is modified in that case.
        """
        # Validate before touching any state so a rejected segment leaves no trace.
        text_bytes = text.encode("utf-8")
        total_bytes = self._sent_text_bytes_by_context.get(context_id, 0) + len(text_bytes)
        if total_bytes > self._MAX_TEXT_BYTES_PER_STREAM:
            raise ValueError("Xfyun Super TTS text must not exceed 64K UTF-8 bytes per stream")

        if status == 0:
            # Bind before sending so the receive task can attribute the first reply.
            self._active_context_id = context_id
            self._started_contexts.add(context_id)

        seq = self._seq_by_context.get(context_id, 0)
        frame = self._build_request_frame(text, status=status, seq=seq)
        await self._get_websocket().send(json.dumps(frame, ensure_ascii=False))

        self._seq_by_context[context_id] = seq + 1
        self._sent_text_bytes_by_context[context_id] = total_bytes

    def _build_request_frame(self, text: str, *, status: int, seq: int) -> dict[str, Any]:
        """Build the JSON request body for one segment.

        The text is sent base64-encoded UTF-8.
        """
        return {
            "header": {
                "app_id": self._app_id,
                "status": status,
            },
            "parameter": {
                "oral": {
                    "oral_level": self._oral_level,
                },
                "tts": {
                    "vcn": self._voice,
                    "speed": self._speed,
                    "volume": self._volume,
                    "pitch": self._pitch,
                    "bgs": 0,
                    "reg": 0,
                    "rdn": 0,
                    "rhy": 0,
                    "audio": {
                        "encoding": self._encoding,
                        "sample_rate": self._source_sample_rate,
                        "channels": 1,
                        "bit_depth": 16,
                        "frame_size": 0,
                    },
                },
            },
            "payload": {
                "text": {
                    "encoding": "utf8",
                    "compress": "raw",
                    "format": "plain",
                    "status": status,
                    "seq": seq,
                    "text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
                },
            },
        }

    async def _reset_context(self, context_id: str) -> None:
        """Forget all stream bookkeeping for ``context_id``; unbind it if it is active."""
        self._started_contexts.discard(context_id)
        self._seq_by_context.pop(context_id, None)
        self._sent_text_bytes_by_context.pop(context_id, None)
        if self._active_context_id == context_id:
            self._active_context_id = None
|
|
|
|
|
|
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return ``url`` with Xfyun HMAC-SHA256 authentication query parameters added.

    The signature is an HMAC-SHA256, keyed with ``api_secret``, over::

        host: <host>
        date: <current UTC time, RFC 1123 "GMT" form>
        GET <path> HTTP/1.1

    The ``authorization``, ``date`` and ``host`` values are URL-encoded into
    the query string. Any query string already present on ``url`` is kept,
    and the auth parameters are appended after it.

    Args:
        url: ``ws://`` or ``wss://`` endpoint URL.
        api_key: Xfyun API key, embedded in the authorization header value.
        api_secret: Xfyun API secret used as the HMAC key.

    Raises:
        ValueError: if ``url`` is not a ``ws``/``wss`` URL with a hostname.
    """
    parsed = urlparse(url)
    if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
        raise ValueError(f"invalid Xfyun Super TTS WebSocket URL: {url}")

    host = parsed.hostname
    path = parsed.path or "/"
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)
    request_line = f"GET {path} HTTP/1.1"

    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")

    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})
    # Merge into any existing query instead of blindly appending "?": a second "?"
    # would fold the auth parameters into the previous parameter's value. Rebuilding
    # through ParseResult also keeps the query ahead of any "#fragment".
    merged_query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    return parsed._replace(query=merged_query).geturl()
|