RimeTTSService: use AudioContextWordTTSService

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-14 11:39:05 -08:00
parent aeadb40c3f
commit f53ee79ddb

View File

@@ -7,7 +7,7 @@
import base64
import json
import uuid
from typing import AsyncGenerator, Optional, Union
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService, WordTTSService
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language
@@ -58,7 +58,7 @@ def language_to_rime_language(language: Language) -> str:
return LANGUAGE_MAP.get(language, "eng")
class RimeTTSService(WordTTSService, WebsocketService):
class RimeTTSService(AudioContextWordTTSService, WebsocketService):
"""Text-to-Speech service using Rime's websocket API.
Uses Rime's websocket JSON API to convert text to speech with word-level timing
@@ -95,7 +95,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
params: Additional configuration parameters.
"""
# Initialize with parent class settings for proper frame handling
WordTTSService.__init__(
AudioContextWordTTSService.__init__(
self,
aggregate_sentences=True,
push_text_frames=False,
@@ -249,12 +249,18 @@ class RimeTTSService(WordTTSService, WebsocketService):
return word_pairs
async def flush_audio(self):
    """Log a flush and reset the active context id.

    Acts only when both ``self._context_id`` and ``self._websocket`` are
    set; otherwise the call is a no-op.

    NOTE(review): nothing is written to the websocket here -- the
    connection is only checked for presence and the local context id is
    cleared. Confirm that no explicit flush message is expected by the
    server.
    """
    # Single positive guard instead of an early return: proceed only when
    # there is an active context and a live websocket.
    if self._context_id and self._websocket:
        logger.trace(f"{self}: flushing audio")
        self._context_id = None
async def _receive_messages(self):
"""Process incoming websocket messages."""
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["contextId"] != self._context_id:
if not msg or not self.audio_context_available(msg["contextId"]):
continue
if msg["type"] == "chunk":
@@ -266,7 +272,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
sample_rate=self.sample_rate,
num_channels=1,
)
await self.push_frame(frame)
await self.append_to_audio_context(msg["contextId"], frame)
elif msg["type"] == "timestamps":
# Process word timing information
@@ -288,6 +294,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
self._context_id = None
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push frame and handle end-of-turn conditions."""
@@ -329,6 +336,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
self._started = True
self._cumulative_time = 0
self._context_id = str(uuid.uuid4())
await self.create_audio_context(self._context_id)
msg = self._build_msg(text=text)
await self._get_websocket().send(json.dumps(msg))