Files
ai-video-fullstack/backend/services/pipecat/xfyun_super_tts.py
Xin Wang e25dfd4003 Add support for Xfyun ASR and TTS services in the backend
- Introduce new Xfyun ASR and TTS services, enabling integration with iFlytek's voice recognition and synthesis capabilities.
- Update AssistantConfig model to include interface types for STT and TTS.
- Enhance credential testing to validate Xfyun credentials.
- Modify service factory to create Xfyun services based on configuration.
- Update README with new configuration details for Xfyun integration.
- Add new frontend components for visualizing audio streams and managing user interactions.
2026-06-11 10:51:08 +08:00

392 lines
15 KiB
Python

from __future__ import annotations
import asyncio
import base64
import hashlib
import hmac
import json
import os
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from email.utils import format_datetime
from typing import Any
from urllib.parse import urlencode, urlparse
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import TextAggregationMode, WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as exc:
logger.error(f"Exception: {exc}")
logger.error("In order to use Xfyun Super TTS, install the websockets package.")
raise Exception(f"Missing module: {exc}") from exc
from .xfyun_tts import _sanitize_text_for_tts
# Default iFlytek/Xfyun Super Smart TTS WebSocket endpoint, used when no ``url`` is given.
DEFAULT_XFYUN_SUPER_TTS_URL = (
    "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6"
)
# Sample rates (Hz) accepted for ``source_sample_rate``; other values are rejected at start.
VALID_SAMPLE_RATES = {8_000, 16_000, 24_000}
class XfyunSuperTTSService(WebsocketTTSService):
    """iFlytek/Xfyun Super Smart TTS using bidirectional WebSocket streaming.

    The service keeps one Xfyun synthesis session open for a Pipecat turn. Each
    ``run_tts`` call sends a text segment with status 0/1, while ``flush_audio``
    sends the terminal status 2 frame. Audio arrives on the receive task and is
    appended to the Pipecat audio context.

    Only one synthesis stream (Pipecat context id) may be active per WebSocket:
    a ``run_tts`` call for a different context while one is in flight yields an
    ``ErrorFrame``. After the server reports the final status 2 response, the
    next ``run_tts`` call reconnects before starting a new stream.
    """

    # Maximum cumulative UTF-8 text size for a single synthesis stream
    # (checked across all segments of the stream in ``_send_request_frame``).
    _MAX_TEXT_BYTES_PER_STREAM = 65536

    def __init__(
        self,
        *,
        app_id: str,
        api_key: str,
        api_secret: str,
        voice: str,
        url: str | None = None,
        sample_rate: int = 16000,
        source_sample_rate: int = 24000,
        encoding: str = "raw",
        speed: int = 50,
        volume: int = 50,
        pitch: int = 50,
        oral_level: str = "mid",
        text_aggregation_mode: TextAggregationMode | str | None = TextAggregationMode.TOKEN,
        open_timeout: float = 30.0,
        **kwargs,
    ) -> None:
        """Create the service.

        Args:
            app_id: Xfyun application id; falls back to ``XFYUN_APP_ID`` when empty.
            api_key: Xfyun API key; falls back to ``XFYUN_API_KEY`` when empty.
            api_secret: Xfyun API secret; falls back to ``XFYUN_API_SECRET`` when empty.
            voice: Voice name, sent as ``vcn`` in every request frame.
            url: WebSocket endpoint; defaults to ``DEFAULT_XFYUN_SUPER_TTS_URL``.
            sample_rate: Sample rate of the audio frames pushed downstream.
            source_sample_rate: Sample rate requested from Xfyun. Audio is
                resampled to ``sample_rate`` when the two differ. Must be one of
                ``VALID_SAMPLE_RATES`` (checked in ``start``).
            encoding: Audio encoding requested from Xfyun. Only ``"raw"`` is
                accepted (checked in ``start``).
            speed: ``speed`` value sent in the ``tts`` parameters.
            volume: ``volume`` value sent in the ``tts`` parameters.
            pitch: ``pitch`` value sent in the ``tts`` parameters.
            oral_level: ``oral_level`` value sent in the ``oral`` parameters.
            text_aggregation_mode: Pipecat text aggregation mode. A string is
                converted to ``TextAggregationMode``.
            open_timeout: ``open_timeout`` passed to the ``websockets`` connect call.
            **kwargs: Forwarded to ``WebsocketTTSService``.
        """
        if isinstance(text_aggregation_mode, str):
            text_aggregation_mode = TextAggregationMode(text_aggregation_mode)
        super().__init__(
            text_aggregation_mode=text_aggregation_mode,
            push_text_frames=True,
            push_stop_frames=False,
            push_start_frame=True,
            pause_frame_processing=False,
            sample_rate=sample_rate,
            settings=TTSSettings(model=None, voice=voice, language=None),
            **kwargs,
        )
        self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
        self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
        self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
        self._voice = voice
        self._url = url or DEFAULT_XFYUN_SUPER_TTS_URL
        self._source_sample_rate = source_sample_rate
        self._encoding = encoding
        self._speed = speed
        self._volume = volume
        self._pitch = pitch
        self._oral_level = oral_level
        self._open_timeout = open_timeout
        self._receive_task: asyncio.Task | None = None
        # Context currently streaming on the socket; audio is routed to it.
        self._active_context_id: str | None = None
        # Contexts that already sent their status 0 (first) frame.
        self._started_contexts: set[str] = set()
        # Next ``seq`` number per context.
        self._seq_by_context: dict[str, int] = {}
        # Cumulative UTF-8 text bytes sent per context (limit check).
        self._sent_text_bytes_by_context: dict[str, int] = {}
        # Set once the server reported status 2 for the current stream.
        self._stream_completed = False

    def can_generate_metrics(self) -> bool:
        """Report that this service emits metrics."""
        return True

    async def start(self, frame: StartFrame) -> None:
        """Start the service and open the Xfyun WebSocket.

        Missing credentials, a non-``raw`` encoding, or an unsupported
        ``source_sample_rate`` are reported via ``push_error``. In those cases
        the service stays unconnected.
        """
        await super().start(frame)
        if not self._app_id or not self._api_key or not self._api_secret:
            await self.push_error(
                error_msg="Xfyun Super TTS requires app_id, api_key, and api_secret"
            )
            return
        if self._encoding != "raw":
            await self.push_error(error_msg="Xfyun Super TTS must use raw PCM audio in Pipecat")
            return
        if self._source_sample_rate not in VALID_SAMPLE_RATES:
            await self.push_error(
                error_msg=(
                    "Xfyun Super TTS source_sample_rate must be one of "
                    f"{sorted(VALID_SAMPLE_RATES)}"
                )
            )
            return
        await self._connect()

    async def stop(self, frame: EndFrame) -> None:
        """Stop the service and close the WebSocket."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame) -> None:
        """Cancel the service and close the WebSocket."""
        await super().cancel(frame)
        await self._disconnect()

    async def flush_audio(self, context_id: str | None = None) -> None:
        """Send the terminal status 2 frame for a started stream.

        Args:
            context_id: Stream to flush; defaults to the active audio context.
                Nothing is sent when there is no WebSocket, or when the stream
                never sent its status 0 frame.
        """
        flush_id = context_id or self.get_active_audio_context_id()
        if not flush_id or not self._websocket:
            return
        if flush_id not in self._started_contexts:
            return
        logger.trace(f"{self}: flushing Xfyun Super TTS stream {flush_id}")
        await self._send_request_frame(flush_id, "", status=2)

    async def on_audio_context_interrupted(self, context_id: str) -> None:
        """Abandon an interrupted stream.

        Resets the stream's local state and reconnects. Presumably this
        discards any server audio still in flight for the stream.
        """
        await self.stop_all_metrics()
        await self._reset_context(context_id)
        await self._disconnect()
        await self._connect()
        await super().on_audio_context_interrupted(context_id)

    async def _connect(self) -> None:
        """Open the WebSocket if needed and ensure a receive task is running.

        A receive task that has already finished (for example, its loop exited
        after the server closed the previous connection) can no longer deliver
        audio. It is replaced rather than treated as still running.
        """
        await super()._connect()
        await self._connect_websocket()
        if self._receive_task is not None and self._receive_task.done():
            self._receive_task = None
        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self) -> None:
        """Cancel the receive task and close the WebSocket."""
        await super()._disconnect()
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None
        await self._disconnect_websocket()

    async def _connect_websocket(self) -> None:
        """Open an authenticated WebSocket unless one is already open.

        Failures (including an invalid URL) are reported via ``push_error`` and
        the ``on_connection_error`` event, and leave ``self._websocket`` as ``None``.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return
            logger.debug("Connecting to Xfyun Super TTS")
            auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
            self._websocket = await websocket_connect(
                auth_url,
                max_size=None,
                open_timeout=self._open_timeout,
            )
            await self._call_event_handler("on_connected")
        except Exception as exc:
            self._websocket = None
            await self.push_error(
                error_msg=f"Unable to connect to Xfyun Super TTS: {exc}",
                exception=exc,
            )
            await self._call_event_handler("on_connection_error", f"{exc}")

    async def _disconnect_websocket(self) -> None:
        """Close the WebSocket and clear all per-stream state.

        Close errors are reported via ``push_error``. State is cleared and
        ``on_disconnected`` fires regardless.
        """
        try:
            await self.stop_all_metrics()
            if self._websocket:
                logger.debug("Disconnecting from Xfyun Super TTS")
                await self._websocket.close()
        except Exception as exc:
            await self.push_error(
                error_msg=f"Error closing Xfyun Super TTS websocket: {exc}",
                exception=exc,
            )
        finally:
            await self.remove_active_audio_context()
            self._websocket = None
            self._active_context_id = None
            self._started_contexts.clear()
            self._seq_by_context.clear()
            self._sent_text_bytes_by_context.clear()
            self._stream_completed = False
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the current WebSocket, raising if not connected."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _receive_messages(self) -> None:
        """Consume server messages and route audio to the active stream.

        Each message is parsed as JSON with a ``header`` (``code``, ``sid``,
        ``message``, ``status``) and an optional ``payload.audio`` carrying
        base64 audio plus its own ``status``. A non-zero ``code`` ends the active
        stream with an error. A status of 2 in either place marks the stream complete.
        """
        async for raw_message in self._get_websocket():
            try:
                message = json.loads(raw_message)
            except json.JSONDecodeError:
                logger.warning(f"{self}: received non-JSON Xfyun Super TTS message: {raw_message!r}")
                continue

            header = message.get("header") or {}
            # A header without ``code`` is treated as an error.
            code = header.get("code", -1)
            sid = header.get("sid")
            context_id = self._active_context_id

            if code != 0:
                error_message = header.get("message", "unknown error")
                await self.push_error(
                    error_msg=f"Xfyun Super TTS error code={code}, sid={sid}: {error_message}"
                )
                await self._finish_context(context_id)
                continue

            audio_obj = (message.get("payload") or {}).get("audio") or {}
            audio_b64 = audio_obj.get("audio")
            if audio_b64 and context_id and self.audio_context_available(context_id):
                await self.stop_ttfb_metrics()
                audio = base64.b64decode(audio_b64)
                if self._source_sample_rate != self.sample_rate:
                    audio = await self._resampler.resample(
                        audio, self._source_sample_rate, self.sample_rate
                    )
                frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
                await self.append_to_audio_context(context_id, frame)

            if audio_obj.get("status") == 2 or header.get("status") == 2:
                await self._finish_context(context_id)
                self._stream_completed = True

    async def _finish_context(self, context_id: str | None) -> None:
        """Close out *context_id*.

        Emits ``TTSStoppedFrame`` and drops the audio context (if it still
        exists), then forgets the stream's local state. Does nothing for ``None``.
        """
        if not context_id:
            return
        if self.audio_context_available(context_id):
            await self.append_to_audio_context(
                context_id, TTSStoppedFrame(context_id=context_id)
            )
            await self.remove_audio_context(context_id)
        await self._reset_context(context_id)

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Stream one text segment into the Xfyun session for *context_id*.

        The first segment of a context opens the stream with status 0; later
        segments continue it with status 1. Audio is delivered asynchronously
        by the receive task. On success this generator only yields ``None``. On
        error it yields ``ErrorFrame``, plus ``TTSStoppedFrame`` when sending failed.

        Args:
            text: Text to synthesize. It is passed through ``_sanitize_text_for_tts``
                first and skipped if that leaves nothing.
            context_id: Pipecat audio context the segment belongs to.
        """
        sanitized = _sanitize_text_for_tts(text)
        if not sanitized:
            return

        if not self._is_streaming_tokens:
            logger.debug(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
        else:
            logger.trace(f"{self}: Generating Xfyun Super TTS [{sanitized}]")

        # The previous stream finished (status 2 received): start the next one
        # on a fresh connection.
        if self._stream_completed and self._websocket:
            await self._disconnect()
            await self._connect()
        if not self._websocket or self._websocket.state is State.CLOSED:
            await self._connect()

        if self._active_context_id and self._active_context_id != context_id:
            yield ErrorFrame(
                error=(
                    "Xfyun Super TTS supports one active synthesis stream per WebSocket; "
                    f"active={self._active_context_id}, new={context_id}"
                )
            )
            return

        try:
            status = 0 if context_id not in self._started_contexts else 1
            await self._send_request_frame(context_id, sanitized, status=status)
            await self.start_tts_usage_metrics(sanitized)
        except Exception as exc:
            yield ErrorFrame(error=f"Xfyun Super TTS request failed: {exc}")
            yield TTSStoppedFrame(context_id=context_id)
            # Drop the possibly broken session together with all stream state.
            await self._disconnect()
            await self._connect()
            return

        yield None

    async def _send_request_frame(self, context_id: str, text: str, *, status: int) -> None:
        """Send one text frame for *context_id* and update its bookkeeping.

        Validation runs before any state is touched, so a rejected frame never
        leaves the stream half-registered.

        Args:
            context_id: Pipecat context the text belongs to.
            text: Text segment; empty for the terminal frame.
            status: 0 for the first segment, 1 for continuation, 2 to finish.

        Raises:
            ValueError: If the stream's cumulative text would exceed 64K UTF-8 bytes.
            Exception: If the WebSocket is not connected; errors raised by the
                WebSocket ``send`` also propagate.
        """
        text_bytes = text.encode("utf-8")
        total_bytes = self._sent_text_bytes_by_context.get(context_id, 0) + len(text_bytes)
        if total_bytes > self._MAX_TEXT_BYTES_PER_STREAM:
            raise ValueError("Xfyun Super TTS text must not exceed 64K UTF-8 bytes per stream")
        websocket = self._get_websocket()
        seq = self._seq_by_context.get(context_id, 0)
        frame = self._build_request_frame(text, status=status, seq=seq)
        if status == 0:
            # Register the stream before sending so the receive task can
            # attribute the first audio response to it.
            self._active_context_id = context_id
            self._started_contexts.add(context_id)
        await websocket.send(json.dumps(frame, ensure_ascii=False))
        self._seq_by_context[context_id] = seq + 1
        self._sent_text_bytes_by_context[context_id] = total_bytes

    def _build_request_frame(self, text: str, *, status: int, seq: int) -> dict[str, Any]:
        """Build the JSON request for one text segment.

        The text is base64-encoded UTF-8. The audio parameters request mono,
        16-bit audio at ``source_sample_rate`` with the configured encoding.
        """
        return {
            "header": {
                "app_id": self._app_id,
                "status": status,
            },
            "parameter": {
                "oral": {
                    "oral_level": self._oral_level,
                },
                "tts": {
                    "vcn": self._voice,
                    "speed": self._speed,
                    "volume": self._volume,
                    "pitch": self._pitch,
                    "bgs": 0,
                    "reg": 0,
                    "rdn": 0,
                    "rhy": 0,
                    "audio": {
                        "encoding": self._encoding,
                        "sample_rate": self._source_sample_rate,
                        "channels": 1,
                        "bit_depth": 16,
                        "frame_size": 0,
                    },
                },
            },
            "payload": {
                "text": {
                    "encoding": "utf8",
                    "compress": "raw",
                    "format": "plain",
                    "status": status,
                    "seq": seq,
                    "text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
                },
            },
        }

    async def _reset_context(self, context_id: str) -> None:
        """Forget local bookkeeping for *context_id*.

        Does not touch the Pipecat audio context.
        """
        self._started_contexts.discard(context_id)
        self._seq_by_context.pop(context_id, None)
        self._sent_text_bytes_by_context.pop(context_id, None)
        if self._active_context_id == context_id:
            self._active_context_id = None
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return *url* with Xfyun HMAC-SHA256 authentication query parameters.

    The signature is computed with *api_secret* over the ``host``, the current
    ``date`` (RFC 1123, GMT), and the ``GET <path> HTTP/1.1`` request line. The
    resulting ``authorization``, ``date`` and ``host`` values are merged into
    the URL's query string. Any query parameters or fragment already present in
    *url* are preserved.

    Raises:
        ValueError: If *url* is not a ``ws``/``wss`` URL with a hostname.
    """
    parsed = urlparse(url)
    if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
        raise ValueError(f"invalid Xfyun Super TTS WebSocket URL: {url}")

    host = parsed.hostname
    path = parsed.path or "/"
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)
    request_line = f"GET {path} HTTP/1.1"
    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})
    # Merge into any existing query instead of blindly appending "?...": a
    # second "?" (or a trailing fragment) would corrupt the auth parameters.
    merged_query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    return parsed._replace(query=merged_query).geturl()