RimeTTSService: use AudioContextWordTTSService
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
-from typing import AsyncGenerator, Optional, Union
|
||||
+from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
-from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
+from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
@@ -58,7 +58,7 @@ def language_to_rime_language(language: Language) -> str:
|
||||
return LANGUAGE_MAP.get(language, "eng")
|
||||
|
||||
|
||||
-class RimeTTSService(WordTTSService, WebsocketService):
|
||||
+class RimeTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
"""Text-to-Speech service using Rime's websocket API.
|
||||
|
||||
Uses Rime's websocket JSON API to convert text to speech with word-level timing
|
||||
@@ -95,7 +95,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
|
||||
params: Additional configuration parameters.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
-WordTTSService.__init__(
|
||||
+AudioContextWordTTSService.__init__(
|
||||
self,
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
@@ -249,12 +249,18 @@ class RimeTTSService(WordTTSService, WebsocketService):
|
||||
|
||||
return word_pairs
|
||||
|
||||
+async def flush_audio(self):
|
||||
+if not self._context_id or not self._websocket:
|
||||
+return
|
||||
+logger.trace(f"{self}: flushing audio")
|
||||
+self._context_id = None
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
-if not msg or msg["contextId"] != self._context_id:
|
||||
+if not msg or not self.audio_context_available(msg["contextId"]):
|
||||
continue
|
||||
|
||||
if msg["type"] == "chunk":
|
||||
@@ -266,7 +272,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
-await self.push_frame(frame)
|
||||
+await self.append_to_audio_context(msg["contextId"], frame)
|
||||
|
||||
elif msg["type"] == "timestamps":
|
||||
# Process word timing information
|
||||
@@ -288,6 +294,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
|
||||
+self._context_id = None
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions."""
|
||||
@@ -329,6 +336,7 @@ class RimeTTSService(WordTTSService, WebsocketService):
|
||||
self._started = True
|
||||
self._cumulative_time = 0
|
||||
self._context_id = str(uuid.uuid4())
|
||||
+await self.create_audio_context(self._context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
|
||||
Reference in New Issue
Block a user