Refactored all AudioContextTTSService based providers to override the new callbacks instead of _handle_interruption(), making provider-specific cleanup cleaner and more explicit

This commit is contained in:
filipi87
2026-02-25 10:18:16 -03:00
parent c09ae6ba6d
commit d899f0af11
7 changed files with 101 additions and 67 deletions

View File

@@ -10,7 +10,7 @@ import asyncio
import base64
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, ClassVar, Dict, Mapping, Optional
from typing import Any, AsyncGenerator, Mapping, Optional
import aiohttp
from loguru import logger
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -392,18 +391,29 @@ class AsyncAITTSService(AudioContextTTSService):
logger.warning(f"{self} keepalive error: {e}")
break
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by closing the current context."""
context_id = self.get_active_audio_context_id()
await super()._handle_interruption(frame, direction)
# Close the current context when interrupted without closing the websocket
async def _close_context(self, context_id: str):
# Async AI requires explicit context closure to free server-side resources,
# both on interruption and on normal completion.
if context_id and self._websocket:
try:
await self._websocket.send(
json.dumps({"context_id": context_id, "close_context": True, "transcript": ""})
)
except Exception as e:
logger.error(f"Error closing context on interruption: {e}")
logger.error(f"{self}: Error closing context {context_id}: {e}")
async def on_audio_context_interrupted(self, context_id: str):
"""Close the Async AI context when the bot is interrupted."""
await self._close_context(context_id)
async def on_audio_context_completed(self, context_id: str):
"""Close the Async AI context after all audio has been played.
Async AI does not send a server-side signal when a context is
exhausted, so Pipecat must explicitly close it with
``close_context: True`` to free server-side resources.
"""
await self._close_context(context_id)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:

View File

@@ -11,7 +11,7 @@ import json
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncGenerator, ClassVar, Dict, List, Literal, Mapping, Optional
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional
from loguru import logger
from pydantic import BaseModel, Field
@@ -21,13 +21,11 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import AudioContextTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
@@ -563,14 +561,22 @@ class CartesiaTTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
context_id = self.get_active_audio_context_id()
await super()._handle_interruption(frame, direction)
async def on_audio_context_interrupted(self, context_id: str):
"""Cancel the active Cartesia context when the bot is interrupted."""
await self.stop_all_metrics()
if context_id:
cancel_msg = json.dumps({"context_id": context_id, "cancel": True})
await self._get_websocket().send(cancel_msg)
async def on_audio_context_completed(self, context_id: str):
"""Close the Cartesia context after all audio has been played.
No close message is needed: the server already considers the context
done once it has sent its ``done`` message, which is handled in
``_process_messages``.
"""
pass
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
context_id = self.get_active_audio_context_id()

View File

@@ -666,14 +666,11 @@ class ElevenLabsTTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by closing the current context."""
# Close the current context when interrupted without closing the websocket
context_id = self.get_active_audio_context_id()
await super()._handle_interruption(frame, direction)
async def _close_context(self, context_id: str):
# ElevenLabs requires that Pipecat explicitly closes contexts to free
# server-side resources, both on interruption and on normal completion.
if context_id and self._websocket:
logger.trace(f"Closing context {context_id} due to interruption")
logger.trace(f"{self}: Closing context {context_id}")
try:
# ElevenLabs requires that Pipecat manages the contexts and closes them
# when they're not longer in use. Since an InterruptionFrame is pushed
@@ -686,8 +683,21 @@ class ElevenLabsTTSService(AudioContextTTSService):
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._partial_word = ""
self._partial_word_start_time = 0.0
self._partial_word = ""
self._partial_word_start_time = 0.0
async def on_audio_context_interrupted(self, context_id: str):
"""Close the ElevenLabs context when the bot is interrupted."""
await self._close_context(context_id)
async def on_audio_context_completed(self, context_id: str):
"""Close the ElevenLabs context after all audio has been played.
ElevenLabs does not send a server-side signal when a context is
exhausted, so Pipecat must explicitly close it with
``close_context: True`` to free server-side resources.
"""
await self._close_context(context_id)
async def _receive_messages(self):
"""Handle incoming WebSocket messages from ElevenLabs."""

View File

@@ -17,13 +17,11 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -265,21 +263,24 @@ class GradiumTTSService(AudioContextTTSService):
except Exception as e:
logger.error(f"{self} exception: {e}")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by resetting context state.
async def on_audio_context_interrupted(self, context_id: str):
"""Called when an audio context is cancelled due to an interruption.
The parent AudioContextTTSService._handle_interruption() cancels the audio context
task and creates a new one. We reset _context_id so the next run_tts() creates a
fresh context. No websocket reconnection needed — audio from the old client_req_id
will be silently dropped since the audio context no longer exists.
Args:
frame: The interruption frame.
direction: The direction of the frame.
No WebSocket message is needed — audio from the interrupted
``client_req_id`` will be silently dropped by the base class once the
audio context no longer exists.
"""
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
async def on_audio_context_completed(self, context_id: str):
"""Called after an audio context has finished playing all of its audio.
No close message is needed: Gradium signals completion with an
``end_of_stream`` message (handled in ``_receive_messages``), after
which the server-side context is already closed.
"""
pass
async def _receive_messages(self):
"""Process incoming websocket messages, demultiplexing by client_req_id."""
# TODO(laurent): This should not be necessary as it should happen when

View File

@@ -681,28 +681,23 @@ class InworldTTSService(AudioContextTTSService):
return word_times
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle an interruption from the Inworld WebSocket TTS service.
Args:
frame: The interruption frame.
direction: The direction of the interruption.
"""
old_context_id = self.get_active_audio_context_id()
logger.trace(f"{self}: Handling interruption, old context: {old_context_id}")
await super()._handle_interruption(frame, direction)
if old_context_id and self._websocket:
logger.trace(f"{self}: Closing context {old_context_id} due to interruption")
async def _close_context(self, context_id: str):
if context_id and self._websocket:
logger.info(f"{self}: Closing context {context_id} due to interruption or completion")
try:
await self._send_close_context(old_context_id)
await self._send_close_context(context_id)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._cumulative_time = 0.0
self._generation_end_time = 0.0
logger.trace(f"{self}: Interruption handled, context reset to None")
async def on_audio_context_interrupted(self, context_id: str):
"""Callback invoked when an audio context has been interrupted."""
await self._close_context(context_id)
async def on_audio_context_completed(self, context_id: str):
"""Callback invoked when an audio context has been completed."""
await self._close_context(context_id)
def _get_websocket(self):
"""Get the websocket for the Inworld WebSocket TTS service.

View File

@@ -18,13 +18,11 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -247,16 +245,19 @@ class ResembleAITTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by stopping current synthesis.
Args:
frame: The interruption frame.
direction: The direction of frame processing.
"""
await super()._handle_interruption(frame, direction)
async def on_audio_context_interrupted(self, context_id: str):
"""Stop metrics when the bot is interrupted."""
await self.stop_all_metrics()
async def on_audio_context_completed(self, context_id: str):
"""Stop metrics after the Resemble AI context finishes playing.
No close message is needed: Resemble AI signals completion with an
``audio_end`` message (handled in ``_process_messages``), after which
the server-side context is already closed.
"""
pass
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
logger.trace(f"{self}: flushing audio")

View File

@@ -458,14 +458,25 @@ class RimeTTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by clearing current context."""
context_id = self.get_active_audio_context_id()
await super()._handle_interruption(frame, direction)
async def _close_context(self, context_id: str):
"""Clear the Rime speech queue and stop metrics."""
await self.stop_all_metrics()
if context_id:
await self._get_websocket().send(json.dumps(self._build_clear_msg()))
async def on_audio_context_interrupted(self, context_id: str):
"""Clear the Rime speech queue and stop metrics when the bot is interrupted."""
await self._close_context(context_id)
async def on_audio_context_completed(self, context_id: str):
"""Clear server-side state and stop metrics after the Rime context finishes playing.
Rime does not send a server-side completion signal (e.g. ``done`` / ``end_of_stream`` /
``audio_end``), so we explicitly send a ``clear`` message to clean up
any residual server-side state once all audio has been delivered.
"""
await self._close_context(context_id)
def _calculate_word_times(self, words: list, starts: list, ends: list) -> list:
"""Calculate word timing pairs with proper spacing and punctuation.