Update Rime TTS services to store the voice in the standard `settings.voice` field instead of the nonstandard `speaker` field

This commit is contained in:
Paul Kompfner
2026-02-18 14:33:12 -05:00
parent b4c5cb258b
commit 416e1cf877
2 changed files with 12 additions and 16 deletions

View File

@@ -100,8 +100,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
await task.queue_frames([LLMRunFrame()])
await asyncio.sleep(10)
logger.info("Updating Rime TTS settings: speedAlpha=1.5")
await task.queue_frame(TTSUpdateSettingsFrame(update=RimeTTSSettings(speedAlpha=1.5)))
logger.info("Updating Rime TTS settings: voice=rex")
await task.queue_frame(TTSUpdateSettingsFrame(update=RimeTTSSettings(voice="rex")))
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):

View File

@@ -13,7 +13,7 @@ using Rime's API for streaming and batch audio synthesis.
import base64
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, ClassVar, Dict, Optional
import aiohttp
from loguru import logger
@@ -75,7 +75,6 @@ class RimeTTSSettings(TTSSettings):
"""Settings for Rime WS JSON and HTTP TTS services.
Parameters:
speaker: Voice speaker ID.
modelId: Rime model identifier.
audioFormat: Audio output format.
samplingRate: Audio sample rate.
@@ -87,7 +86,6 @@ class RimeTTSSettings(TTSSettings):
inlineSpeedAlpha: Inline speed control markup.
"""
speaker: str = field(default_factory=lambda: NOT_GIVEN)
modelId: str = field(default_factory=lambda: NOT_GIVEN)
audioFormat: str = field(default_factory=lambda: NOT_GIVEN)
samplingRate: int = field(default_factory=lambda: NOT_GIVEN)
@@ -98,13 +96,14 @@ class RimeTTSSettings(TTSSettings):
phonemizeBetweenBrackets: bool = field(default_factory=lambda: NOT_GIVEN)
inlineSpeedAlpha: str = field(default_factory=lambda: NOT_GIVEN)
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
@dataclass
class RimeNonJsonTTSSettings(TTSSettings):
"""Settings for Rime non-JSON WS TTS service.
Parameters:
speaker: Voice speaker ID.
modelId: Rime model identifier.
audioFormat: Audio output format.
samplingRate: Audio sample rate.
@@ -115,7 +114,6 @@ class RimeNonJsonTTSSettings(TTSSettings):
top_p: Cumulative probability threshold (0.0-1.0).
"""
speaker: str = field(default_factory=lambda: NOT_GIVEN)
modelId: str = field(default_factory=lambda: NOT_GIVEN)
audioFormat: str = field(default_factory=lambda: NOT_GIVEN)
samplingRate: int = field(default_factory=lambda: NOT_GIVEN)
@@ -125,6 +123,8 @@ class RimeNonJsonTTSSettings(TTSSettings):
temperature: float = field(default_factory=lambda: NOT_GIVEN)
top_p: float = field(default_factory=lambda: NOT_GIVEN)
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
class RimeTTSService(AudioContextWordTTSService):
"""Text-to-Speech service using Rime's websocket API.
@@ -210,7 +210,7 @@ class RimeTTSService(AudioContextWordTTSService):
self._voice_id = voice_id
self._model = model
self._settings = RimeTTSSettings(
speaker=voice_id,
voice=voice_id,
modelId=model,
audioFormat="pcm",
samplingRate=0,
@@ -273,10 +273,8 @@ class RimeTTSService(AudioContextWordTTSService):
async def _update_settings(self, update: TTSSettings) -> dict[str, Any]:
"""Apply a settings update and reconnect if voice changed."""
prev_voice = self._voice_id
changed = await super()._update_settings(update)
if "voice" in changed:
self._settings.speaker = self._voice_id
await self._disconnect()
await self._connect()
else:
@@ -355,7 +353,7 @@ class RimeTTSService(AudioContextWordTTSService):
params = "&".join(
f"{k}={v}"
for k, v in {
"speaker": self._settings.speaker,
"speaker": self._settings.voice,
"modelId": self._settings.modelId,
"audioFormat": self._settings.audioFormat,
"samplingRate": self._settings.samplingRate,
@@ -772,7 +770,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
self._voice_id = voice_id
self._model = model
self._settings = RimeNonJsonTTSSettings(
speaker=voice_id,
voice=voice_id,
modelId=model,
audioFormat=audio_format,
samplingRate=sample_rate,
@@ -866,7 +864,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
return
# Build URL with query parameters (only given, non-None values)
settings_dict = {
"speaker": self._settings.speaker,
"speaker": self._settings.voice,
"modelId": self._settings.modelId,
"audioFormat": self._settings.audioFormat,
"samplingRate": self._settings.samplingRate,
@@ -985,9 +983,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
"""
changed = await super()._update_settings(update)
# Sync voice and model to settings dict fields
if "voice" in changed:
self._settings.speaker = self._voice_id
# Sync model to settings dict field
if "model" in changed:
self._settings.modelId = self._model_name