Fix GradiumTTSService to reuse context IDs across multiple run_tts calls and prevent the parent class from pushing text frames.

This commit is contained in:
filipi87
2026-02-18 12:12:49 -03:00
parent f181e12d8f
commit 1daea78b91

View File

@@ -6,6 +6,7 @@
import base64
import json
import uuid
from typing import Any, AsyncGenerator, Mapping, Optional
from loguru import logger
@@ -74,6 +75,7 @@ class GradiumTTSService(AudioContextWordTTSService):
"""
super().__init__(
push_stop_frames=True,
push_text_frames=False,
pause_frame_processing=True,
sample_rate=SAMPLE_RATE,
**kwargs,
@@ -304,6 +306,20 @@ class GradiumTTSService(AudioContextWordTTSService):
await self.stop_all_metrics()
await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")
def create_context_id(self) -> str:
    """Return the context ID to use for a TTS request.

    A context ID that is already in progress is reused; a new one is
    generated only when none is currently set.

    Returns:
        The existing context ID if one is set, otherwise a freshly
        generated unique string identifier.
    """
    # Reuse the current ID whenever there is one. Per the original note,
    # an interruption (caused by user speech) resets the context ID, and
    # the next call after that produces a brand-new UUID.
    return self._context_id or str(uuid.uuid4())
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Gradium's streaming API.