Add xfyun super tts service
This commit is contained in:
@@ -181,6 +181,8 @@ class TTSConfig:
|
||||
pitch: int = 50
|
||||
timeout_sec: float = 30.0
|
||||
source_sample_rate_hz: int | None = None
|
||||
oral_level: str = "mid"
|
||||
text_aggregation_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -10,11 +10,13 @@ from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig
|
||||
from .fastgpt_llm import FastGPTLLMService, FastGPTLLMSettings
|
||||
from .xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService
|
||||
from .xfyun_super_tts import DEFAULT_XFYUN_SUPER_TTS_URL, XfyunSuperTTSService
|
||||
from .xfyun_tts import DEFAULT_XFYUN_TTS_URL, XfyunTTSService
|
||||
|
||||
|
||||
@@ -107,6 +109,30 @@ def create_tts_service(config: TTSConfig, audio: AudioConfig):
|
||||
timeout=config.timeout_sec,
|
||||
)
|
||||
|
||||
if config.provider in ("xfyun_super", "xfyun_super_tts"):
|
||||
source_sample_rate = config.source_sample_rate_hz or 24000
|
||||
if source_sample_rate not in (8000, 16000, 24000):
|
||||
raise ValueError(
|
||||
"Xfyun Super TTS source_sample_rate_hz must be 8000, 16000, or 24000"
|
||||
)
|
||||
text_aggregation_mode = config.text_aggregation_mode or TextAggregationMode.TOKEN
|
||||
return XfyunSuperTTSService(
|
||||
app_id=config.app_id,
|
||||
api_key=config.api_key or "",
|
||||
api_secret=config.api_secret,
|
||||
voice=config.voice,
|
||||
url=config.base_url or DEFAULT_XFYUN_SUPER_TTS_URL,
|
||||
sample_rate=audio.sample_rate_hz,
|
||||
source_sample_rate=source_sample_rate,
|
||||
encoding=config.aue,
|
||||
speed=config.speed,
|
||||
volume=config.volume,
|
||||
pitch=config.pitch,
|
||||
oral_level=config.oral_level,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
open_timeout=config.timeout_sec,
|
||||
)
|
||||
|
||||
_require_provider(config.provider, "openai", "tts")
|
||||
service_class = OpenAITTSService if config.voice in VALID_VOICES else OpenAICompatibleTTSService
|
||||
return service_class(
|
||||
|
||||
391
engine/xfyun_super_tts.py
Normal file
391
engine/xfyun_super_tts.py
Normal file
@@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import format_datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import TextAggregationMode, WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as exc:
|
||||
logger.error(f"Exception: {exc}")
|
||||
logger.error("In order to use Xfyun Super TTS, install the websockets package.")
|
||||
raise Exception(f"Missing module: {exc}") from exc
|
||||
|
||||
from .xfyun_tts import _sanitize_text_for_tts
|
||||
|
||||
|
||||
DEFAULT_XFYUN_SUPER_TTS_URL = "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6"
|
||||
VALID_SAMPLE_RATES = {8000, 16000, 24000}
|
||||
|
||||
|
||||
class XfyunSuperTTSService(WebsocketTTSService):
|
||||
"""iFlytek/Xfyun Super Smart TTS using bidirectional WebSocket streaming.
|
||||
|
||||
The service keeps one Xfyun synthesis session open for a Pipecat turn. Each
|
||||
``run_tts`` call sends a text segment with status 0/1, while ``flush_audio``
|
||||
sends the terminal status 2 frame. Audio arrives on the receive task and is
|
||||
appended to the Pipecat audio context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_id: str,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
voice: str,
|
||||
url: str | None = None,
|
||||
sample_rate: int = 16000,
|
||||
source_sample_rate: int = 24000,
|
||||
encoding: str = "raw",
|
||||
speed: int = 50,
|
||||
volume: int = 50,
|
||||
pitch: int = 50,
|
||||
oral_level: str = "mid",
|
||||
text_aggregation_mode: TextAggregationMode | str | None = TextAggregationMode.TOKEN,
|
||||
open_timeout: float = 30.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if isinstance(text_aggregation_mode, str):
|
||||
text_aggregation_mode = TextAggregationMode(text_aggregation_mode)
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=True,
|
||||
push_stop_frames=False,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=False,
|
||||
sample_rate=sample_rate,
|
||||
settings=TTSSettings(model=None, voice=voice, language=None),
|
||||
**kwargs,
|
||||
)
|
||||
self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "")
|
||||
self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "")
|
||||
self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "")
|
||||
self._voice = voice
|
||||
self._url = url or DEFAULT_XFYUN_SUPER_TTS_URL
|
||||
self._source_sample_rate = source_sample_rate
|
||||
self._encoding = encoding
|
||||
self._speed = speed
|
||||
self._volume = volume
|
||||
self._pitch = pitch
|
||||
self._oral_level = oral_level
|
||||
self._open_timeout = open_timeout
|
||||
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._active_context_id: str | None = None
|
||||
self._started_contexts: set[str] = set()
|
||||
self._seq_by_context: dict[str, int] = {}
|
||||
self._sent_text_bytes_by_context: dict[str, int] = {}
|
||||
self._stream_completed = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame) -> None:
|
||||
await super().start(frame)
|
||||
if not self._app_id or not self._api_key or not self._api_secret:
|
||||
await self.push_error(
|
||||
error_msg="Xfyun Super TTS requires app_id, api_key, and api_secret"
|
||||
)
|
||||
return
|
||||
if self._encoding != "raw":
|
||||
await self.push_error(error_msg="Xfyun Super TTS must use raw PCM audio in Pipecat")
|
||||
return
|
||||
if self._source_sample_rate not in VALID_SAMPLE_RATES:
|
||||
await self.push_error(
|
||||
error_msg=(
|
||||
"Xfyun Super TTS source_sample_rate must be one of "
|
||||
f"{sorted(VALID_SAMPLE_RATES)}"
|
||||
)
|
||||
)
|
||||
return
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self, context_id: str | None = None) -> None:
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
if flush_id not in self._started_contexts:
|
||||
return
|
||||
|
||||
logger.trace(f"{self}: flushing Xfyun Super TTS stream {flush_id}")
|
||||
await self._send_request_frame(flush_id, "", status=2)
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str) -> None:
|
||||
await self.stop_all_metrics()
|
||||
await self._reset_context(context_id)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
await super().on_audio_context_interrupted(context_id)
|
||||
|
||||
async def _connect(self) -> None:
|
||||
await super()._connect()
|
||||
await self._connect_websocket()
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
await super()._disconnect()
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self) -> None:
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
logger.debug("Connecting to Xfyun Super TTS")
|
||||
auth_url = _build_auth_url(self._url, self._api_key, self._api_secret)
|
||||
self._websocket = await websocket_connect(
|
||||
auth_url,
|
||||
max_size=None,
|
||||
open_timeout=self._open_timeout,
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as exc:
|
||||
self._websocket = None
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to Xfyun Super TTS: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
await self._call_event_handler("on_connection_error", f"{exc}")
|
||||
|
||||
async def _disconnect_websocket(self) -> None:
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Xfyun Super TTS")
|
||||
await self._websocket.close()
|
||||
except Exception as exc:
|
||||
await self.push_error(
|
||||
error_msg=f"Error closing Xfyun Super TTS websocket: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
finally:
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
self._active_context_id = None
|
||||
self._started_contexts.clear()
|
||||
self._seq_by_context.clear()
|
||||
self._sent_text_bytes_by_context.clear()
|
||||
self._stream_completed = False
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self) -> None:
|
||||
async for raw_message in self._get_websocket():
|
||||
try:
|
||||
message = json.loads(raw_message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"{self}: received non-JSON Xfyun Super TTS message: {raw_message!r}")
|
||||
continue
|
||||
|
||||
header = message.get("header") or {}
|
||||
code = header.get("code", -1)
|
||||
sid = header.get("sid")
|
||||
context_id = self._active_context_id
|
||||
|
||||
if code != 0:
|
||||
error_message = header.get("message", "unknown error")
|
||||
await self.push_error(
|
||||
error_msg=f"Xfyun Super TTS error code={code}, sid={sid}: {error_message}"
|
||||
)
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
if context_id:
|
||||
await self._reset_context(context_id)
|
||||
continue
|
||||
|
||||
audio_obj = (message.get("payload") or {}).get("audio") or {}
|
||||
audio_b64 = audio_obj.get("audio")
|
||||
if audio_b64 and context_id and self.audio_context_available(context_id):
|
||||
await self.stop_ttfb_metrics()
|
||||
audio = base64.b64decode(audio_b64)
|
||||
if self._source_sample_rate != self.sample_rate:
|
||||
audio = await self._resampler.resample(
|
||||
audio, self._source_sample_rate, self.sample_rate
|
||||
)
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
|
||||
audio_status = audio_obj.get("status")
|
||||
header_status = header.get("status")
|
||||
if audio_status == 2 or header_status == 2:
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
if context_id:
|
||||
await self._reset_context(context_id)
|
||||
self._stream_completed = True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
|
||||
sanitized = _sanitize_text_for_tts(text)
|
||||
if not sanitized:
|
||||
return
|
||||
|
||||
if not self._is_streaming_tokens:
|
||||
logger.debug(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
|
||||
else:
|
||||
logger.trace(f"{self}: Generating Xfyun Super TTS [{sanitized}]")
|
||||
|
||||
if self._stream_completed and self._websocket:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if self._active_context_id and self._active_context_id != context_id:
|
||||
yield ErrorFrame(
|
||||
error=(
|
||||
"Xfyun Super TTS supports one active synthesis stream per WebSocket; "
|
||||
f"active={self._active_context_id}, new={context_id}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
status = 0 if context_id not in self._started_contexts else 1
|
||||
await self._send_request_frame(context_id, sanitized, status=status)
|
||||
await self.start_tts_usage_metrics(sanitized)
|
||||
except Exception as exc:
|
||||
yield ErrorFrame(error=f"Xfyun Super TTS request failed: {exc}")
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
yield None
|
||||
|
||||
async def _send_request_frame(self, context_id: str, text: str, *, status: int) -> None:
|
||||
if status == 0:
|
||||
self._active_context_id = context_id
|
||||
self._started_contexts.add(context_id)
|
||||
|
||||
seq = self._seq_by_context.get(context_id, 0)
|
||||
text_bytes = text.encode("utf-8")
|
||||
total_bytes = self._sent_text_bytes_by_context.get(context_id, 0) + len(text_bytes)
|
||||
if total_bytes > 65536:
|
||||
raise ValueError("Xfyun Super TTS text must not exceed 64K UTF-8 bytes per stream")
|
||||
|
||||
frame = self._build_request_frame(text, status=status, seq=seq)
|
||||
await self._get_websocket().send(json.dumps(frame, ensure_ascii=False))
|
||||
|
||||
self._seq_by_context[context_id] = seq + 1
|
||||
self._sent_text_bytes_by_context[context_id] = total_bytes
|
||||
|
||||
def _build_request_frame(self, text: str, *, status: int, seq: int) -> dict[str, Any]:
|
||||
return {
|
||||
"header": {
|
||||
"app_id": self._app_id,
|
||||
"status": status,
|
||||
},
|
||||
"parameter": {
|
||||
"oral": {
|
||||
"oral_level": self._oral_level,
|
||||
},
|
||||
"tts": {
|
||||
"vcn": self._voice,
|
||||
"speed": self._speed,
|
||||
"volume": self._volume,
|
||||
"pitch": self._pitch,
|
||||
"bgs": 0,
|
||||
"reg": 0,
|
||||
"rdn": 0,
|
||||
"rhy": 0,
|
||||
"audio": {
|
||||
"encoding": self._encoding,
|
||||
"sample_rate": self._source_sample_rate,
|
||||
"channels": 1,
|
||||
"bit_depth": 16,
|
||||
"frame_size": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
"payload": {
|
||||
"text": {
|
||||
"encoding": "utf8",
|
||||
"compress": "raw",
|
||||
"format": "plain",
|
||||
"status": status,
|
||||
"seq": seq,
|
||||
"text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def _reset_context(self, context_id: str) -> None:
|
||||
self._started_contexts.discard(context_id)
|
||||
self._seq_by_context.pop(context_id, None)
|
||||
self._sent_text_bytes_by_context.pop(context_id, None)
|
||||
if self._active_context_id == context_id:
|
||||
self._active_context_id = None
|
||||
|
||||
|
||||
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
|
||||
raise ValueError(f"invalid Xfyun Super TTS WebSocket URL: {url}")
|
||||
|
||||
host = parsed.hostname
|
||||
path = parsed.path or "/"
|
||||
date = format_datetime(datetime.now(timezone.utc), usegmt=True)
|
||||
request_line = f"GET {path} HTTP/1.1"
|
||||
signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
|
||||
signature_sha = hmac.new(
|
||||
api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
signature = base64.b64encode(signature_sha).decode("utf-8")
|
||||
authorization_origin = (
|
||||
f'api_key="{api_key}", algorithm="hmac-sha256", '
|
||||
f'headers="host date request-line", signature="{signature}"'
|
||||
)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
|
||||
query = urlencode({"authorization": authorization, "date": date, "host": host})
|
||||
return f"{url}?{query}"
|
||||
Reference in New Issue
Block a user