Add Volcengine support for TTS and ASR services
- Introduced Volcengine as a new provider for both TTS and ASR services. - Updated configuration files to include Volcengine-specific parameters such as app_id, resource_id, and uid. - Enhanced the ASR service to support streaming mode with Volcengine's API. - Modified existing tests to validate the integration of Volcengine services. - Updated documentation to reflect the addition of Volcengine as a supported provider for TTS and ASR. - Refactored service factory to accommodate Volcengine alongside existing providers.
This commit is contained in:
219
engine/providers/tts/volcengine.py
Normal file
219
engine/providers/tts/volcengine.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Volcengine TTS service.
|
||||
|
||||
Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed
|
||||
base64 audio chunks into engine-native ``TTSChunk`` events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import codecs
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
|
||||
class VolcengineTTSService(BaseTTSService):
|
||||
"""Streaming TTS adapter for Volcengine's HTTP v3 API."""
|
||||
|
||||
DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
|
||||
DEFAULT_RESOURCE_ID = "seed-tts-2.0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "zh_female_shuangkuaisisi_moon_bigtts",
|
||||
model: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
uid: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY")
|
||||
self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL
|
||||
self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None
|
||||
self.app_id = app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID")
|
||||
self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID
|
||||
self.uid = uid or os.getenv("VOLCENGINE_TTS_UID")
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._synthesis_lock = asyncio.Lock()
|
||||
self._pending_audio: list[bytes] = []
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}",
|
||||
self.voice,
|
||||
self.sample_rate,
|
||||
self.resource_id,
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._cancel_event.set()
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Volcengine TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
audio = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio += chunk.audio
|
||||
return audio
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
if not self._session:
|
||||
raise RuntimeError("Volcengine TTS service not connected")
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
async with self._synthesis_lock:
|
||||
self._cancel_event.clear()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Api-App-Key": str(self.app_id),
|
||||
"X-Api-Access-Key": str(self.api_key),
|
||||
"X-Api-Resource-Id": str(self.resource_id),
|
||||
"X-Api-Request-Id": str(uuid.uuid4()),
|
||||
}
|
||||
payload = {
|
||||
"user": {
|
||||
"uid": str(self.uid or self.app_id),
|
||||
},
|
||||
"req_params": {
|
||||
"text": text,
|
||||
"speaker": self.voice,
|
||||
"audio_params": {
|
||||
"format": "pcm",
|
||||
"sample_rate": self.sample_rate,
|
||||
"speech_rate": self._speech_rate_percent(self.speed),
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.model:
|
||||
payload["req_params"]["model"] = self.model
|
||||
|
||||
chunk_size = max(1, self.sample_rate * 2 // 10)
|
||||
audio_buffer = b""
|
||||
pending_chunk: Optional[bytes] = None
|
||||
|
||||
try:
|
||||
async with self._session.post(self.api_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}")
|
||||
|
||||
async for audio_bytes in self._iter_audio_bytes(response):
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("Volcengine TTS synthesis cancelled")
|
||||
return
|
||||
|
||||
audio_buffer += audio_bytes
|
||||
while len(audio_buffer) >= chunk_size:
|
||||
emitted = audio_buffer[:chunk_size]
|
||||
audio_buffer = audio_buffer[chunk_size:]
|
||||
if pending_chunk is not None:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = emitted
|
||||
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
if pending_chunk is not None:
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = None
|
||||
else:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
|
||||
pending_chunk = None
|
||||
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Volcengine TTS synthesis cancelled via asyncio")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Volcengine TTS synthesis error: {}", exc)
|
||||
raise
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_event.set()
|
||||
|
||||
async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
|
||||
decoder = json.JSONDecoder()
|
||||
utf8_decoder = codecs.getincrementaldecoder("utf-8")()
|
||||
text_buffer = ""
|
||||
self._pending_audio.clear()
|
||||
|
||||
async for raw_chunk in response.content.iter_any():
|
||||
text_buffer += utf8_decoder.decode(raw_chunk)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
text_buffer += utf8_decoder.decode(b"", final=True)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str:
|
||||
while True:
|
||||
stripped = text_buffer.lstrip()
|
||||
if not stripped:
|
||||
return ""
|
||||
if len(stripped) != len(text_buffer):
|
||||
text_buffer = stripped
|
||||
|
||||
try:
|
||||
payload, idx = decoder.raw_decode(text_buffer)
|
||||
except json.JSONDecodeError:
|
||||
return text_buffer
|
||||
|
||||
text_buffer = text_buffer[idx:]
|
||||
audio = self._extract_audio_bytes(payload)
|
||||
if audio:
|
||||
self._pending_audio.append(audio)
|
||||
|
||||
def _extract_audio_bytes(self, payload: Any) -> bytes:
|
||||
if not isinstance(payload, dict):
|
||||
return b""
|
||||
|
||||
code = payload.get("code")
|
||||
if code not in (None, 0, 20000000):
|
||||
message = str(payload.get("message") or "unknown error")
|
||||
raise RuntimeError(f"Volcengine TTS stream error {code}: {message}")
|
||||
|
||||
encoded = payload.get("data")
|
||||
if isinstance(encoded, str) and encoded.strip():
|
||||
try:
|
||||
return base64.b64decode(encoded)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc)
|
||||
return b""
|
||||
|
||||
@staticmethod
|
||||
def _speech_rate_percent(speed: float) -> int:
|
||||
clamped = max(0.5, min(2.0, float(speed or 1.0)))
|
||||
return int(round((clamped - 1.0) * 100))
|
||||
Reference in New Issue
Block a user