Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules.
- Introduced new `adapters` and `protocol` packages for better organization.
- Added backend adapter implementations for control-plane integration.
- Updated main application imports to reflect the new package structure.
- Removed deprecated core components and adjusted documentation accordingly.
- Enhanced architecture documentation to clarify the new runtime and integration layers.
This commit is contained in:
1
engine/providers/tts/__init__.py
Normal file
1
engine/providers/tts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""TTS providers."""
|
||||
352
engine/providers/tts/dashscope.py
Normal file
352
engine/providers/tts/dashscope.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""DashScope realtime TTS service.
|
||||
|
||||
Implements DashScope's Qwen realtime TTS protocol via the official SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import audioop
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
# Optional dependency: the DashScope SDK. When it is absent we stub out the
# imported names so this module still imports cleanly; connect() raises a
# clear error at runtime instead.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
||||
|
||||
|
||||
class _RealtimeEventCallback(QwenTtsRealtimeCallback):
    """Bridge SDK callback events into an asyncio queue.

    The DashScope SDK invokes these handlers from its own worker thread; each
    one marshals the event onto the service's asyncio queue via
    ``call_soon_threadsafe`` so the event loop can consume it safely.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"):
        super().__init__()
        self._loop = loop
        self._queue = queue

    def _push(self, event: Dict[str, Any]) -> None:
        # Thread-safe hand-off to the event loop. If the loop is already
        # closed (shutdown race), drop the event instead of crashing the
        # SDK's callback thread.
        try:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except RuntimeError:
            return

    def on_open(self) -> None:
        self._push({"type": "session.open"})

    def on_close(self, code: int, reason: str) -> None:
        # Synthetic "__close__" type lets the consumer distinguish transport
        # closure from protocol-level events.
        self._push({"type": "__close__", "code": code, "reason": reason})

    def on_error(self, message: str) -> None:
        self._push({"type": "error", "error": {"message": str(message)}})

    def on_event(self, event: Any) -> None:
        # Normalize the SDK event payload (dict, JSON string, or anything
        # else) into a dict before queueing.
        if isinstance(event, dict):
            payload = event
        elif isinstance(event, str):
            try:
                payload = json.loads(event)
            except json.JSONDecodeError:
                payload = {"type": "raw", "message": event}
        else:
            payload = {"type": "raw", "message": str(event)}
        self._push(payload)

    def on_data(self, data: bytes) -> None:
        # Some SDK versions provide audio via on_data directly.
        self._push({"type": "response.audio.delta.raw", "audio": data})
|
||||
|
||||
|
||||
class DashScopeTTSService(BaseTTSService):
    """DashScope realtime TTS service using Qwen Realtime protocol.

    Audio always arrives from the provider as 24 kHz 16-bit mono PCM and is
    resampled to the caller-requested ``sample_rate`` when they differ.
    """

    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-tts-flash-realtime"
    # Fixed provider output rate (see AudioFormat.PCM_24000HZ_MONO_16BIT in connect()).
    PROVIDER_SAMPLE_RATE = 24000

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "Cherry",
        model: Optional[str] = None,
        mode: str = "commit",
        sample_rate: int = 16000,
        speed: float = 1.0,
    ):
        """Configure the service; no network activity until connect().

        Args:
            api_key: DashScope API key (required at connect() time).
            api_url: Realtime websocket URL; env vars DASHSCOPE_TTS_API_URL /
                TTS_API_URL are consulted before the default.
            voice: Voice name passed to update_session.
            model: Model name; DASHSCOPE_TTS_MODEL env var is the fallback.
            mode: "commit" or "server_commit"; anything else falls back to
                "server_commit" with a warning.
            sample_rate: Desired output PCM sample rate.
            speed: Speech speed (stored by the base class).
        """
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key
        self.api_url = (
            api_url
            or os.getenv("DASHSCOPE_TTS_API_URL")
            or os.getenv("TTS_API_URL")
            or self.DEFAULT_WS_URL
        )
        self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL

        normalized_mode = str(mode or "").strip().lower()
        if normalized_mode not in {"server_commit", "commit"}:
            logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit")
            normalized_mode = "server_commit"
        self.mode = normalized_mode

        self._client: Optional[Any] = None
        # Events from the SDK callback thread are funneled through this queue.
        self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        self._callback: Optional[_RealtimeEventCallback] = None
        self._cancel_event = asyncio.Event()
        # Serializes synthesize_stream calls: one utterance at a time per session.
        self._synthesis_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Open the realtime websocket session and configure voice/format.

        Raises:
            RuntimeError: If the dashscope SDK is not installed.
            ValueError: If no API key was provided.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            raise RuntimeError("dashscope package not installed; install with `pip install dashscope`")
        if not self.api_key:
            raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.")

        loop = asyncio.get_running_loop()
        self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue)
        # The official Python SDK docs set key via global `dashscope.api_key`;
        # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor.
        if dashscope is not None:
            dashscope.api_key = self.api_key
        self._client = self._create_realtime_client(self._callback)

        # SDK calls are blocking; run them off the event loop.
        await asyncio.to_thread(self._client.connect)
        await asyncio.to_thread(
            self._client.update_session,
            voice=self.voice,
            response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
            mode=self.mode,
        )
        await self._wait_for_session_ready()

        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime TTS service ready: "
            f"voice={self.voice}, model={self.model}, mode={self.mode}"
        )

    def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any:
        """Instantiate the SDK client, tolerating ctor signature differences."""
        init_kwargs = {
            "model": self.model,
            "callback": callback,
            "url": self.api_url,
        }
        try:
            return QwenTtsRealtime(  # type: ignore[misc]
                api_key=self.api_key,
                **init_kwargs,
            )
        except TypeError as exc:
            # Only swallow the "unexpected keyword api_key" case; re-raise
            # any other TypeError.
            if "api_key" not in str(exc):
                raise
            logger.debug(
                "QwenTtsRealtime does not support `api_key` ctor arg; "
                "falling back to global dashscope.api_key auth"
            )
            return QwenTtsRealtime(**init_kwargs)  # type: ignore[misc]

    async def disconnect(self) -> None:
        """Close the websocket, drop queued events, and reset state."""
        self._cancel_event.set()
        if self._client:
            close_fn = getattr(self._client, "close", None)
            if callable(close_fn):
                await asyncio.to_thread(close_fn)
            self._client = None
        self._drain_event_queue()
        self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize the full utterance and return concatenated PCM bytes."""
        audio = b""
        async for chunk in self.synthesize_stream(text):
            audio += chunk.audio
        return audio

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Stream ~100 ms PCM chunks for *text*.

        Yields:
            TTSChunk objects; the last chunk is tagged is_final=True. One
            full chunk is always held back (``pending_chunk``) so the final
            flag can be set correctly even when the stream length is an
            exact multiple of the chunk size.

        Raises:
            RuntimeError: If not connected, on a provider "error" event, or
                if the websocket closes mid-response.
        """
        if not self._client:
            raise RuntimeError("DashScope TTS service not connected")
        if not text.strip():
            return

        async with self._synthesis_lock:
            self._cancel_event.clear()
            self._drain_event_queue()

            # Drop any text left over from a cancelled previous turn first.
            await self._clear_appended_text()
            await asyncio.to_thread(self._client.append_text, text)
            if self.mode == "commit":
                await asyncio.to_thread(self._client.commit)

            chunk_size = max(1, self.sample_rate * 2 // 10)  # 100ms
            buffer = b""
            pending_chunk: Optional[bytes] = None
            resample_state: Any = None  # audioop.ratecv carry-over state

            while True:
                # Shorter timeout once cancelled: we only linger to drain
                # the response cleanly.
                timeout = 8.0 if self._cancel_event.is_set() else 20.0
                event = await self._next_event(timeout=timeout)
                event_type = str(event.get("type") or "").strip()

                if event_type in {"response.audio.delta", "response.audio.delta.raw"}:
                    if self._cancel_event.is_set():
                        continue

                    pcm = self._decode_audio_event(event)
                    if not pcm:
                        continue
                    pcm, resample_state = self._resample_if_needed(pcm, resample_state)
                    if not pcm:
                        continue

                    buffer += pcm
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False,
                            )
                        pending_chunk = audio_chunk
                    continue

                if event_type == "response.done":
                    break

                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))

                if event_type == "__close__":
                    reason = str(event.get("reason") or "unknown")
                    raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}")

                # Unrecognized event while cancelled: stop waiting for done.
                if self._cancel_event.is_set():
                    return

            # Flush the held-back chunk, then the tail remainder.
            if pending_chunk is not None:
                if buffer:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                    pending_chunk = None
                else:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
                    pending_chunk = None

            if buffer:
                yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True)

    async def cancel(self) -> None:
        """Request cancellation of the in-flight synthesis (best effort)."""
        self._cancel_event.set()
        if self.mode == "commit":
            # In commit mode, dropping un-committed text is sufficient.
            await self._clear_appended_text()
            return

        if not self._client:
            return
        # SDK method name varies by version.
        cancel_fn = (
            getattr(self._client, "cancel_response", None)
            or getattr(self._client, "cancel", None)
        )
        if callable(cancel_fn):
            try:
                await asyncio.to_thread(cancel_fn)
            except Exception as exc:
                # Best-effort: a failed cancel only delays stopping.
                logger.debug(f"DashScope cancel failed: {exc}")

    async def _wait_for_session_ready(self) -> None:
        """Wait for a session.updated/session.open event after update_session.

        A timeout is treated as non-fatal: the websocket is assumed usable.
        """
        try:
            while True:
                event = await self._next_event(timeout=8.0)
                event_type = str(event.get("type") or "").strip()
                if event_type in {"session.updated", "session.open"}:
                    return
                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))
        except asyncio.TimeoutError:
            logger.debug("DashScope session update event timeout; continuing with active websocket")

    async def _clear_appended_text(self) -> None:
        """Drop text appended but not yet committed (commit mode only)."""
        if self.mode != "commit":
            return
        if not self._client:
            return
        # Not all SDK versions expose this; skip silently when missing.
        clear_fn = getattr(self._client, "clear_appended_text", None)
        if callable(clear_fn):
            try:
                await asyncio.to_thread(clear_fn)
            except Exception as exc:
                logger.debug(f"DashScope clear_appended_text failed: {exc}")

    async def _next_event(self, timeout: float) -> Dict[str, Any]:
        """Pop the next callback event, coercing non-dict payloads to dicts."""
        event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout)
        if isinstance(event, dict):
            return event
        return {"type": "raw", "message": str(event)}

    def _drain_event_queue(self) -> None:
        """Discard all queued events (stale data from a previous turn)."""
        while True:
            try:
                self._event_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    def _decode_audio_event(self, event: Dict[str, Any]) -> bytes:
        """Extract raw PCM from an audio event.

        Raw events carry bytes directly; protocol events carry a base64
        "delta" string. Returns b"" for anything undecodable.
        """
        event_type = str(event.get("type") or "")
        if event_type == "response.audio.delta.raw":
            audio = event.get("audio")
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            return b""

        delta = event.get("delta")
        if isinstance(delta, str):
            try:
                return base64.b64decode(delta)
            except Exception as exc:
                logger.warning(f"Failed to decode DashScope audio delta: {exc}")
                return b""
        if isinstance(delta, (bytes, bytearray)):
            return bytes(delta)
        return b""

    def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]:
        """Resample 24 kHz provider PCM to self.sample_rate, keeping ratecv state.

        NOTE(review): `audioop` is deprecated and removed in Python 3.13
        (PEP 594) — confirm the runtime version or plan a replacement.
        """
        if self.sample_rate == self.PROVIDER_SAMPLE_RATE:
            return pcm, state
        try:
            converted, next_state = audioop.ratecv(
                pcm,
                2,  # 16-bit PCM
                1,  # mono
                self.PROVIDER_SAMPLE_RATE,
                self.sample_rate,
                state,
            )
            return converted, next_state
        except Exception as exc:
            # Degrade gracefully: emit provider-rate audio rather than fail.
            logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data")
            return pcm, state

    @staticmethod
    def _format_error_event(event: Dict[str, Any]) -> str:
        """Render a provider error event as a "code: message" string."""
        err = event.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "").strip()
            message = str(err.get("message") or "").strip()
            if code and message:
                return f"{code}: {message}"
            return message or str(err)
        return str(err or "DashScope realtime TTS error")
|
||||
49
engine/providers/tts/mock.py
Normal file
49
engine/providers/tts/mock.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""TTS service implementations used by the engine runtime."""
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncIterator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, TTSChunk, ServiceState
|
||||
|
||||
|
||||
class MockTTSService(BaseTTSService):
    """Mock TTS service for tests and no-provider fallback.

    Emits silence (zero-valued 16-bit mono PCM) sized from the word count of
    the input text, so pipelines can run end-to-end without a real provider.
    """

    def __init__(
        self,
        voice: str = "mock",
        sample_rate: int = 16000,
        speed: float = 1.0,
    ):
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)

    async def connect(self) -> None:
        """Mark the service as connected; nothing real to open."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock TTS service connected")

    async def disconnect(self) -> None:
        """Mark the service as disconnected; nothing real to close."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Generate silence based on text length (100 ms per word)."""
        duration_ms = 100 * len(text.split())
        sample_count = int(self.sample_rate * duration_ms / 1000)
        # Two zero bytes per sample = silent 16-bit PCM.
        return bytes(sample_count * 2)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Generate silence chunks to emulate streaming synthesis."""
        pcm = await self.synthesize(text)

        step = self.sample_rate * 2 // 10  # 100 ms of 16-bit mono PCM
        offset = 0
        while offset < len(pcm):
            piece = pcm[offset:offset + step]
            offset += step
            yield TTSChunk(
                audio=piece,
                sample_rate=self.sample_rate,
                is_final=offset >= len(pcm),
            )
            await asyncio.sleep(0.05)
|
||||
352
engine/providers/tts/openai_compatible.py
Normal file
352
engine/providers/tts/openai_compatible.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""OpenAI-compatible TTS Service with streaming support.
|
||||
|
||||
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||
text-to-speech synthesis with streaming.
|
||||
|
||||
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import AsyncIterator, Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, TTSChunk, ServiceState
|
||||
from providers.tts.streaming_adapter import StreamingTTSAdapter # backward-compatible re-export
|
||||
|
||||
|
||||
class OpenAICompatibleTTSService(BaseTTSService):
    """
    OpenAI-compatible TTS service with streaming support.

    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
    """

    # Available voices: short name -> fully-qualified "model:voice" id.
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize OpenAI-compatible TTS service.

        Args:
            api_key: Provider API key
            api_url: Provider API URL (defaults to SiliconFlow endpoint)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve voice name (case-insensitive), and normalize "model:VoiceId" suffix.
        resolved_voice = (voice or "").strip()
        voice_lookup = resolved_voice.lower()
        if voice_lookup in self.VOICES:
            full_voice = self.VOICES[voice_lookup]
        elif ":" in resolved_voice:
            # Caller supplied "model:voice" directly; re-normalize the voice
            # part if it is one we know, otherwise pass through untouched.
            model_part, voice_part = resolved_voice.split(":", 1)
            normalized_voice_part = voice_part.strip().lower()
            if normalized_voice_part in self.VOICES:
                full_voice = f"{(model_part or model).strip()}:{normalized_voice_part}"
            else:
                full_voice = resolved_voice
        else:
            full_voice = resolved_voice

        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)

        self.api_key = api_key
        self.model = model
        raw_api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
        self.api_url = self._resolve_speech_endpoint(raw_api_url)

        self._session: Optional[aiohttp.ClientSession] = None
        self._cancel_event = asyncio.Event()

    @staticmethod
    def _resolve_speech_endpoint(api_url: str) -> str:
        """
        Accept either:
          - base URL: https://host/v1
          - full speech endpoint: https://host/v1/audio/speech
        and always return the final speech endpoint URL.
        """
        raw = str(api_url or "").strip()
        if not raw:
            return "https://api.siliconflow.cn/v1/audio/speech"

        parsed = urlparse(raw)
        path = (parsed.path or "").rstrip("/")
        if path.endswith("/audio/speech"):
            return raw

        if not path:
            new_path = "/audio/speech"
        else:
            new_path = f"{path}/audio/speech"

        return urlunparse(parsed._replace(path=new_path))

    async def connect(self) -> None:
        """Initialize HTTP session.

        Raises:
            ValueError: If no API key was configured.
        """
        if not self.api_key:
            raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")

        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text."""
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")

        if not text.strip():
            return

        self._cancel_event.clear()

        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }

        try:
            async with self._session.post(self.api_url, json=payload) as response:
                # Non-200: log and end the stream rather than raising, so a
                # provider hiccup degrades to silence instead of a crash.
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    return

                # Stream audio chunks
                chunk_size = self.sample_rate * 2 // 10  # 100ms chunks
                buffer = b""
                pending_chunk = None

                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return

                    buffer += chunk

                    # Yield complete chunks
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]

                        # Keep one full chunk buffered so we can always tag the true
                        # last full chunk as final when stream length is an exact multiple.
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending_chunk = audio_chunk

                # Flush pending chunk(s) and remaining tail.
                if pending_chunk is not None:
                    if buffer:
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=False
                        )
                        pending_chunk = None
                    else:
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=True
                        )
                        pending_chunk = None

                if buffer:
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis."""
        self._cancel_event.set()
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.

    NOTE(review): this definition shadows the ``StreamingTTSAdapter``
    re-exported above from ``providers.tts.streaming_adapter``, and its
    delimiter set diverges from that module's (this one also splits on the
    fullwidth comma). Confirm which implementation is intended and remove
    the duplicate.
    """

    # Sentence delimiters (fullwidth comma included -> clause-level splits).
    SENTENCE_ENDS = {',', '。', '!', '?', '.', '!', '?', '\n'}

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    def _is_non_sentence_period(self, text: str, idx: int) -> bool:
        """Check whether '.' should NOT be treated as a sentence delimiter."""
        if text[idx] != ".":
            return False

        # Decimal/version segment: 1.2, v1.2.3
        if idx > 0 and idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True

        # Number abbreviations: No.1 / No. 1
        left_start = idx - 1
        while left_start >= 0 and text[left_start].isalpha():
            left_start -= 1
        left_token = text[left_start + 1:idx].lower()
        if left_token == "no":
            j = idx + 1
            while j < len(text) and text[j].isspace():
                j += 1
            if j < len(text) and text[j].isdigit():
                return True

        return False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return

        self._buffer += text_chunk

        # Check for sentence completion
        while True:
            split_idx = -1
            for i, char in enumerate(self._buffer):
                if char == "." and self._is_non_sentence_period(self._buffer, i):
                    continue
                if char in self.SENTENCE_ENDS:
                    split_idx = i
                    break
            if split_idx < 0:
                break

            # Swallow any run of consecutive delimiters after the split point.
            end_idx = split_idx + 1
            while end_idx < len(self._buffer) and self._buffer[end_idx] in self.SENTENCE_ENDS:
                end_idx += 1

            sentence = self._buffer[:end_idx].strip()
            self._buffer = self._buffer[end_idx:]

            # Only speak segments that contain something pronounceable.
            if sentence and any(ch.isalnum() for ch in sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Flush remaining buffer."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize and send a sentence."""
        if not text or self._cancel_event.is_set():
            return

        self._is_speaking = True

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for new turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        return self._is_speaking
|
||||
|
||||
|
||||
# Backward-compatible alias: older code imported the provider-specific name.
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
8
engine/providers/tts/siliconflow.py
Normal file
8
engine/providers/tts/siliconflow.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Backward-compatible imports for legacy siliconflow_tts module."""
|
||||
|
||||
from providers.tts.openai_compatible import OpenAICompatibleTTSService, StreamingTTSAdapter
|
||||
|
||||
# Backward-compatible alias
|
||||
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
|
||||
__all__ = ["OpenAICompatibleTTSService", "SiliconFlowTTSService", "StreamingTTSAdapter"]
|
||||
95
engine/providers/tts/streaming_adapter.py
Normal file
95
engine/providers/tts/streaming_adapter.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Backend-agnostic streaming adapter from LLM text to TTS audio."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService
|
||||
from providers.common.streaming_text import extract_tts_sentence, has_spoken_content
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Sentence-level bridge from streaming LLM text to TTS audio.

    Buffers incoming LLM text chunks and dispatches each completed sentence
    to the TTS service as soon as it is detected, instead of waiting for the
    whole response — cutting time-to-first-audio.
    """

    SENTENCE_ENDS = {"。", "!", "?", ".", "!", "?", "\n"}
    SENTENCE_CLOSERS = frozenset()

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Accumulate one LLM text chunk and speak any completed sentences.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return

        self._buffer += text_chunk

        # Pull sentences out of the buffer until none remain complete.
        delimiters = frozenset(self.SENTENCE_ENDS)
        while True:
            extracted = extract_tts_sentence(
                self._buffer,
                end_chars=delimiters,
                trailing_chars=delimiters,
                closers=self.SENTENCE_CLOSERS,
                force=False,
            )
            if not extracted:
                break

            sentence, self._buffer = extracted
            if sentence and has_spoken_content(sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Speak whatever text remains buffered, then clear the buffer."""
        remainder = self._buffer.strip()
        if remainder and not self._cancel_event.is_set():
            await self._speak_sentence(remainder)
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Run one sentence through TTS and forward the audio to transport."""
        if not text:
            return
        if self._cancel_event.is_set():
            return

        self._is_speaking = True
        try:
            async for piece in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(piece.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Stop ongoing speech and discard buffered text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Return the adapter to a clean state for the next turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        """Whether a sentence is currently being synthesized/sent."""
        return self._is_speaking
|
||||
Reference in New Issue
Block a user