Update Rime TTS services to store voice in the standard settings.voice field, as opposed to the nonstandard speaker field
This commit is contained in:
@@ -100,8 +100,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
await asyncio.sleep(10)
|
||||
logger.info("Updating Rime TTS settings: speedAlpha=1.5")
|
||||
await task.queue_frame(TTSUpdateSettingsFrame(update=RimeTTSSettings(speedAlpha=1.5)))
|
||||
logger.info("Updating Rime TTS settings: voice=rex")
|
||||
await task.queue_frame(TTSUpdateSettingsFrame(update=RimeTTSSettings(voice="rex")))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -13,7 +13,7 @@ using Rime's API for streaming and batch audio synthesis.
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
from typing import Any, AsyncGenerator, ClassVar, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -75,7 +75,6 @@ class RimeTTSSettings(TTSSettings):
|
||||
"""Settings for Rime WS JSON and HTTP TTS services.
|
||||
|
||||
Parameters:
|
||||
speaker: Voice speaker ID.
|
||||
modelId: Rime model identifier.
|
||||
audioFormat: Audio output format.
|
||||
samplingRate: Audio sample rate.
|
||||
@@ -87,7 +86,6 @@ class RimeTTSSettings(TTSSettings):
|
||||
inlineSpeedAlpha: Inline speed control markup.
|
||||
"""
|
||||
|
||||
speaker: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
modelId: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
audioFormat: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
samplingRate: int = field(default_factory=lambda: NOT_GIVEN)
|
||||
@@ -98,13 +96,14 @@ class RimeTTSSettings(TTSSettings):
|
||||
phonemizeBetweenBrackets: bool = field(default_factory=lambda: NOT_GIVEN)
|
||||
inlineSpeedAlpha: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RimeNonJsonTTSSettings(TTSSettings):
|
||||
"""Settings for Rime non-JSON WS TTS service.
|
||||
|
||||
Parameters:
|
||||
speaker: Voice speaker ID.
|
||||
modelId: Rime model identifier.
|
||||
audioFormat: Audio output format.
|
||||
samplingRate: Audio sample rate.
|
||||
@@ -115,7 +114,6 @@ class RimeNonJsonTTSSettings(TTSSettings):
|
||||
top_p: Cumulative probability threshold (0.0-1.0).
|
||||
"""
|
||||
|
||||
speaker: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
modelId: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
audioFormat: str = field(default_factory=lambda: NOT_GIVEN)
|
||||
samplingRate: int = field(default_factory=lambda: NOT_GIVEN)
|
||||
@@ -125,6 +123,8 @@ class RimeNonJsonTTSSettings(TTSSettings):
|
||||
temperature: float = field(default_factory=lambda: NOT_GIVEN)
|
||||
top_p: float = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
|
||||
|
||||
|
||||
class RimeTTSService(AudioContextWordTTSService):
|
||||
"""Text-to-Speech service using Rime's websocket API.
|
||||
@@ -210,7 +210,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._voice_id = voice_id
|
||||
self._model = model
|
||||
self._settings = RimeTTSSettings(
|
||||
speaker=voice_id,
|
||||
voice=voice_id,
|
||||
modelId=model,
|
||||
audioFormat="pcm",
|
||||
samplingRate=0,
|
||||
@@ -273,10 +273,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _update_settings(self, update: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings update and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
changed = await super()._update_settings(update)
|
||||
if "voice" in changed:
|
||||
self._settings.speaker = self._voice_id
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
else:
|
||||
@@ -355,7 +353,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
params = "&".join(
|
||||
f"{k}={v}"
|
||||
for k, v in {
|
||||
"speaker": self._settings.speaker,
|
||||
"speaker": self._settings.voice,
|
||||
"modelId": self._settings.modelId,
|
||||
"audioFormat": self._settings.audioFormat,
|
||||
"samplingRate": self._settings.samplingRate,
|
||||
@@ -772,7 +770,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
self._voice_id = voice_id
|
||||
self._model = model
|
||||
self._settings = RimeNonJsonTTSSettings(
|
||||
speaker=voice_id,
|
||||
voice=voice_id,
|
||||
modelId=model,
|
||||
audioFormat=audio_format,
|
||||
samplingRate=sample_rate,
|
||||
@@ -866,7 +864,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
return
|
||||
# Build URL with query parameters (only given, non-None values)
|
||||
settings_dict = {
|
||||
"speaker": self._settings.speaker,
|
||||
"speaker": self._settings.voice,
|
||||
"modelId": self._settings.modelId,
|
||||
"audioFormat": self._settings.audioFormat,
|
||||
"samplingRate": self._settings.samplingRate,
|
||||
@@ -985,9 +983,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
# Sync voice and model to settings dict fields
|
||||
if "voice" in changed:
|
||||
self._settings.speaker = self._voice_id
|
||||
# Sync model to settings dict field
|
||||
if "model" in changed:
|
||||
self._settings.modelId = self._model_name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user