diff --git a/src/pipecat/services/gradium/tts.py b/src/pipecat/services/gradium/tts.py index bde77f846..ef3cbbfde 100644 --- a/src/pipecat/services/gradium/tts.py +++ b/src/pipecat/services/gradium/tts.py @@ -6,6 +6,7 @@ import base64 import json +import uuid from typing import Any, AsyncGenerator, Mapping, Optional from loguru import logger @@ -74,6 +75,7 @@ class GradiumTTSService(AudioContextWordTTSService): """ super().__init__( push_stop_frames=True, + push_text_frames=False, pause_frame_processing=True, sample_rate=SAMPLE_RATE, **kwargs, @@ -304,6 +306,20 @@ class GradiumTTSService(AudioContextWordTTSService): await self.stop_all_metrics() await self.push_error(error_msg=f"Error: {msg.get('message', msg)}") + def create_context_id(self) -> str: + """Return the in-progress context ID, or generate a new unique one if none is set. + + Returns: + The current context ID when one is set; otherwise a newly generated UUID4 string. + """ + # If no context ID is set, generate a new one; otherwise keep + # using the current ID. User speech results in an interruption, + # which resets the context ID (the reset happens elsewhere; + # this method itself never assigns self._context_id). + if not self._context_id: + return str(uuid.uuid4()) + return self._context_id + @traced_tts async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: """Generate speech from text using Gradium's streaming API.