Improve Async TTS interruption handling by using AudioContextTTSService class and add changelog fragments

This commit is contained in:
Ashot
2026-01-07 15:55:35 +04:00
parent 9cdbc56be3
commit 5ae592f38e
4 changed files with 5 additions and 126 deletions

View File

@@ -6,15 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
<!-- towncrier release notes start -->
## [Unreleased]
### Changed
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
### Fixed
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
## [0.0.99] - 2026-01-13

View File

@@ -0,0 +1 @@
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.

1
changelog/3287.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.

View File

@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import WebsocketTTSService, TTSService
from pipecat.services.tts_service import AudioContextTTSService, WebsocketTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -73,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class AsyncAITTSService(WebsocketTTSService):
class AsyncAITTSService(AudioContextTTSService, WebsocketTTSService):
"""Async TTS service with WebSocket streaming.
Provides text-to-speech using Async's streaming WebSocket API.
@@ -154,55 +154,6 @@ class AsyncAITTSService(WebsocketTTSService):
self._keepalive_task = None
self._started = False
async def create_audio_context(self, context_id: str):
    """Register a new audio context that groups related audio.

    The context ID is announced on the contexts queue and a dedicated
    queue for the context's items is stored under that ID.

    Args:
        context_id: Unique identifier for the audio context.
    """
    announce_queue = self._contexts_queue
    await announce_queue.put(context_id)
    # Each context gets its own queue for the items appended to it later.
    self._contexts[context_id] = asyncio.Queue()
    logger.trace(f"{self} created audio context {context_id}")
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
    """Queue an audio frame on an already-registered context.

    If the context is not registered, a warning is logged and the frame
    is not queued.

    Args:
        context_id: The context to append audio to.
        frame: The audio frame to append.
    """
    if not self.audio_context_available(context_id):
        logger.warning(f"{self} unable to append audio to context {context_id}")
        return

    logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
    target_queue = self._contexts[context_id]
    await target_queue.put(frame)
async def remove_audio_context(self, context_id: str):
    """Mark a registered audio context for removal.

    If the context is not registered, a warning is logged and nothing
    else happens.

    Args:
        context_id: The context to remove.
    """
    if not self.audio_context_available(context_id):
        logger.warning(f"{self} unable to remove context {context_id}")
        return

    # The context is not deleted here. A None marker is queued instead;
    # once the audio handling reaches that marker it knows the context
    # can be removed safely.
    logger.trace(f"{self} marking audio context {context_id} for deletion")
    context_queue = self._contexts[context_id]
    await context_queue.put(None)
def audio_context_available(self, context_id: str) -> bool:
    """Tell whether an audio context is currently registered.

    Args:
        context_id: The context ID to look up.

    Returns:
        True if the ID is present in the registered contexts, False otherwise.
    """
    registered_contexts = self._contexts
    return context_id in registered_contexts
async def start(self, frame: StartFrame):
"""Start the Async TTS service.
@@ -210,7 +161,6 @@ class AsyncAITTSService(WebsocketTTSService):
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._create_audio_context_task()
self._settings["output_format"]["sample_rate"] = self.sample_rate
await self._connect()
@@ -221,12 +171,6 @@ class AsyncAITTSService(WebsocketTTSService):
frame: The end frame.
"""
await super().stop(frame)
if self._audio_context_task:
# Indicate no more audio contexts are available. This will end the
# task cleanly after all contexts have been processed.
await self._contexts_queue.put(None)
await self._audio_context_task
self._audio_context_task = None
await self._disconnect()
async def cancel(self, frame: CancelFrame):
@@ -236,65 +180,7 @@ class AsyncAITTSService(WebsocketTTSService):
frame: The cancel frame.
"""
await super().cancel(frame)
await self._stop_audio_context_task()
await self._disconnect()
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """Handle an interruption by restarting the audio context task.

    Args:
        frame: The interruption frame being handled.
        direction: The direction the frame is travelling in.
    """
    # Let the parent class run its own interruption handling first.
    await super()._handle_interruption(frame, direction)
    # Cancel the current audio context task, then start a new one; creating
    # the task also replaces the contexts queue and the contexts mapping.
    await self._stop_audio_context_task()
    self._create_audio_context_task()
def _create_audio_context_task(self):
    """Start the audio context task unless one already exists.

    When a task is created, the contexts queue and the contexts mapping
    are replaced with fresh, empty ones first.
    """
    if self._audio_context_task:
        # A task already exists; leave it and its state untouched.
        return

    self._contexts_queue = asyncio.Queue()
    self._contexts: Dict[str, asyncio.Queue] = {}
    handler = self._audio_context_task_handler()
    self._audio_context_task = self.create_task(handler)
async def _stop_audio_context_task(self):
    """Cancel the audio context task, if any, and clear the reference to it."""
    current_task = self._audio_context_task
    if not current_task:
        return

    await self.cancel_task(current_task)
    self._audio_context_task = None
async def _audio_context_task_handler(self):
    """Process queued audio contexts one at a time, in arrival order.

    Runs until a falsy entry (the end-of-contexts marker) is taken from the
    contexts queue.
    """
    while True:
        context_id = await self._contexts_queue.get()
        is_end_marker = not context_id

        if not is_end_marker:
            # Push this context's audio until the context has no more
            # audio available (i.e. its None marker is found).
            await self._handle_audio_context(context_id)

            # The context has been fully processed, so it can be dropped.
            del self._contexts[context_id]

            # Follow the context with some silence between sentences.
            pause = TTSAudioRawFrame(
                audio=b"\x00" * self.sample_rate,
                sample_rate=self.sample_rate,
                num_channels=1,
            )
            await self.push_frame(pause)

        # Every get() is acknowledged, including the end marker.
        self._contexts_queue.task_done()

        if is_end_marker:
            break
async def _handle_audio_context(self, context_id: str):
    """Push the queued items of one audio context until it is finished.

    The context is considered finished when a None marker is read from its
    queue or when no item arrives within the timeout.

    Args:
        context_id: The registered context whose queue is drained.
    """
    # If we don't receive any audio during this time, we consider the context finished.
    AUDIO_CONTEXT_TIMEOUT = 3.0
    pending = self._contexts[context_id]
    while True:
        try:
            item = await asyncio.wait_for(pending.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
            if item is None:
                # End-of-context marker.
                return
            if item:
                await self.push_frame(item)
        except asyncio.TimeoutError:
            # We didn't get audio, so let's consider this context finished.
            logger.trace(f"{self} time out on audio context {context_id}")
            return
await self._disconnect()
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.