Fix GradiumTTSService to reuse context IDs across multiple run_tts calls and prevent the parent class from pushing text frames.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -74,6 +75,7 @@ class GradiumTTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
**kwargs,
|
||||
@@ -304,6 +306,20 @@ class GradiumTTSService(AudioContextWordTTSService):
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")
|
||||
|
||||
def create_context_id(self) -> str:
    """Return the context ID to use for a TTS request.

    The context ID that is already in progress is reused when there is one;
    a new UUID4 string is generated only when no context ID is set.

    Returns:
        The current context ID if one is set, otherwise a newly generated
        unique string identifier.
    """
    # Keep using the current ID while one exists so that consecutive
    # requests share the same context. Per the original author's note, an
    # interruption (caused by user speech) resets the context ID, after
    # which a fresh one is generated here.
    active_id = self._context_id
    if active_id:
        return active_id
    return str(uuid.uuid4())
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Gradium's streaming API.
|
||||
|
||||
Reference in New Issue
Block a user