diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index 0210f8de6..1634578cd 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -7,7 +7,7 @@ import base64 import json import uuid -from typing import AsyncGenerator, Optional, Union +from typing import AsyncGenerator, Optional import aiohttp from loguru import logger @@ -28,7 +28,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import TTSService, WordTTSService +from pipecat.services.ai_services import AudioContextWordTTSService, TTSService from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language @@ -58,7 +58,7 @@ def language_to_rime_language(language: Language) -> str: return LANGUAGE_MAP.get(language, "eng") -class RimeTTSService(WordTTSService, WebsocketService): +class RimeTTSService(AudioContextWordTTSService, WebsocketService): """Text-to-Speech service using Rime's websocket API. Uses Rime's websocket JSON API to convert text to speech with word-level timing @@ -95,7 +95,7 @@ class RimeTTSService(WordTTSService, WebsocketService): params: Additional configuration parameters. """ # Initialize with parent class settings for proper frame handling - WordTTSService.__init__( + AudioContextWordTTSService.__init__( self, aggregate_sentences=True, push_text_frames=False, @@ -249,12 +249,18 @@ class RimeTTSService(WordTTSService, WebsocketService): return word_pairs + async def flush_audio(self): + if not self._context_id or not self._websocket: + return + logger.trace(f"{self}: flushing audio") + self._context_id = None + async def _receive_messages(self): """Process incoming websocket messages.""" async for message in self._get_websocket(): msg = json.loads(message) - if not msg or msg["contextId"] != self._context_id: + if not msg or not self.audio_context_available(msg["contextId"]): continue if msg["type"] == "chunk": @@ -266,7 +272,7 @@ class RimeTTSService(WordTTSService, WebsocketService): sample_rate=self.sample_rate, num_channels=1, ) - await self.push_frame(frame) + await self.append_to_audio_context(msg["contextId"], frame) elif msg["type"] == "timestamps": # Process word timing information @@ -288,6 +294,7 @@ class RimeTTSService(WordTTSService, WebsocketService): await self.push_frame(TTSStoppedFrame()) await self.stop_all_metrics() await self.push_error(ErrorFrame(f"{self} error: {msg['message']}")) + self._context_id = None async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): """Push frame and handle end-of-turn conditions.""" @@ -329,6 +336,7 @@ class RimeTTSService(WordTTSService, WebsocketService): self._started = True self._cumulative_time = 0 self._context_id = str(uuid.uuid4()) + await self.create_audio_context(self._context_id) msg = self._build_msg(text=text) await self._get_websocket().send(json.dumps(msg))