Improve Async TTS interruption handling by using AudioContextTTSService class and add changelog fragments
This commit is contained in:
@@ -6,15 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
<!-- towncrier release notes start -->
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
|
||||
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
|
||||
|
||||
## [0.0.99] - 2026-01-13
|
||||
|
||||
|
||||
1
changelog/3287.changed.md
Normal file
1
changelog/3287.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
|
||||
1
changelog/3287.fixed.md
Normal file
1
changelog/3287.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
|
||||
@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import WebsocketTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, WebsocketTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -73,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
|
||||
|
||||
class AsyncAITTSService(WebsocketTTSService):
|
||||
class AsyncAITTSService(AudioContextTTSService, WebsocketTTSService):
|
||||
"""Async TTS service with WebSocket streaming.
|
||||
|
||||
Provides text-to-speech using Async's streaming WebSocket API.
|
||||
@@ -154,55 +154,6 @@ class AsyncAITTSService(WebsocketTTSService):
|
||||
self._keepalive_task = None
|
||||
self._started = False
|
||||
|
||||
async def create_audio_context(self, context_id: str):
|
||||
"""Create a new audio context for grouping related audio.
|
||||
|
||||
Args:
|
||||
context_id: Unique identifier for the audio context.
|
||||
"""
|
||||
await self._contexts_queue.put(context_id)
|
||||
self._contexts[context_id] = asyncio.Queue()
|
||||
logger.trace(f"{self} created audio context {context_id}")
|
||||
|
||||
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
|
||||
"""Append audio to an existing context.
|
||||
|
||||
Args:
|
||||
context_id: The context to append audio to.
|
||||
frame: The audio frame to append.
|
||||
"""
|
||||
if self.audio_context_available(context_id):
|
||||
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
|
||||
await self._contexts[context_id].put(frame)
|
||||
else:
|
||||
logger.warning(f"{self} unable to append audio to context {context_id}")
|
||||
|
||||
async def remove_audio_context(self, context_id: str):
|
||||
"""Remove an existing audio context.
|
||||
|
||||
Args:
|
||||
context_id: The context to remove.
|
||||
"""
|
||||
if self.audio_context_available(context_id):
|
||||
# We just mark the audio context for deletion by appending
|
||||
# None. Once we reach None while handling audio we know we can
|
||||
# safely remove the context.
|
||||
logger.trace(f"{self} marking audio context {context_id} for deletion")
|
||||
await self._contexts[context_id].put(None)
|
||||
else:
|
||||
logger.warning(f"{self} unable to remove context {context_id}")
|
||||
|
||||
def audio_context_available(self, context_id: str) -> bool:
|
||||
"""Check whether the given audio context is registered.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to check.
|
||||
|
||||
Returns:
|
||||
True if the context exists and is available.
|
||||
"""
|
||||
return context_id in self._contexts
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Async TTS service.
|
||||
|
||||
@@ -210,7 +161,6 @@ class AsyncAITTSService(WebsocketTTSService):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._create_audio_context_task()
|
||||
self._settings["output_format"]["sample_rate"] = self.sample_rate
|
||||
await self._connect()
|
||||
|
||||
@@ -221,12 +171,6 @@ class AsyncAITTSService(WebsocketTTSService):
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if self._audio_context_task:
|
||||
# Indicate no more audio contexts are available. this will end the
|
||||
# task cleanly after all contexts have been processed.
|
||||
await self._contexts_queue.put(None)
|
||||
await self._audio_context_task
|
||||
self._audio_context_task = None
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
@@ -236,65 +180,7 @@ class AsyncAITTSService(WebsocketTTSService):
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_audio_context_task()
|
||||
await self._disconnect()
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self._stop_audio_context_task()
|
||||
self._create_audio_context_task()
|
||||
|
||||
def _create_audio_context_task(self):
|
||||
if not self._audio_context_task:
|
||||
self._contexts_queue = asyncio.Queue()
|
||||
self._contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = self.create_task(self._audio_context_task_handler())
|
||||
|
||||
async def _stop_audio_context_task(self):
|
||||
if self._audio_context_task:
|
||||
await self.cancel_task(self._audio_context_task)
|
||||
self._audio_context_task = None
|
||||
|
||||
async def _audio_context_task_handler(self):
|
||||
"""In this task we process audio contexts in order."""
|
||||
running = True
|
||||
while running:
|
||||
context_id = await self._contexts_queue.get()
|
||||
|
||||
if context_id:
|
||||
# Process the audio context until the context doesn't have more
|
||||
# audio available (i.e. we find None).
|
||||
await self._handle_audio_context(context_id)
|
||||
|
||||
# We just finished processing the context, so we can safely remove it.
|
||||
del self._contexts[context_id]
|
||||
|
||||
# Append some silence between sentences.
|
||||
silence = b"\x00" * self.sample_rate
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=silence, sample_rate=self.sample_rate, num_channels=1
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
running = False
|
||||
|
||||
self._contexts_queue.task_done()
|
||||
|
||||
async def _handle_audio_context(self, context_id: str):
|
||||
# If we don't receive any audio during this time, we consider the context finished.
|
||||
AUDIO_CONTEXT_TIMEOUT = 3.0
|
||||
queue = self._contexts[context_id]
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
running = frame is not None
|
||||
except asyncio.TimeoutError:
|
||||
# We didn't get audio, so let's consider this context finished.
|
||||
logger.trace(f"{self} time out on audio context {context_id}")
|
||||
break
|
||||
await self._disconnect()
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Reference in New Issue
Block a user