diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index d55062c4f..f1f73b7ff 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -10,7 +10,7 @@ import asyncio import base64 import json from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, ClassVar, Dict, Mapping, Optional +from typing import Any, AsyncGenerator, Mapping, Optional import aiohttp from loguru import logger @@ -21,7 +21,6 @@ from pipecat.frames.frames import ( EndFrame, ErrorFrame, Frame, - InterruptionFrame, StartFrame, TTSAudioRawFrame, TTSStartedFrame, @@ -392,18 +391,29 @@ class AsyncAITTSService(AudioContextTTSService): logger.warning(f"{self} keepalive error: {e}") break - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by closing the current context.""" - context_id = self.get_active_audio_context_id() - await super()._handle_interruption(frame, direction) - # Close the current context when interrupted without closing the websocket + async def _close_context(self, context_id: str): + # Async AI requires explicit context closure to free server-side resources, + # both on interruption and on normal completion. if context_id and self._websocket: try: await self._websocket.send( json.dumps({"context_id": context_id, "close_context": True, "transcript": ""}) ) except Exception as e: - logger.error(f"Error closing context on interruption: {e}") + logger.error(f"{self}: Error closing context {context_id}: {e}") + + async def on_audio_context_interrupted(self, context_id: str): + """Close the Async AI context when the bot is interrupted.""" + await self._close_context(context_id) + + async def on_audio_context_completed(self, context_id: str): + """Close the Async AI context after all audio has been played. 
+ + Async AI does not send a server-side signal when a context is + exhausted, so Pipecat must explicitly close it with + ``close_context: True`` to free server-side resources. + """ + await self._close_context(context_id) @traced_tts async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index f31cc2421..f45e7c54f 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -11,7 +11,7 @@ import json import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncGenerator, ClassVar, Dict, List, Literal, Mapping, Optional +from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional from loguru import logger from pydantic import BaseModel, Field @@ -21,13 +21,11 @@ from pipecat.frames.frames import ( EndFrame, ErrorFrame, Frame, - InterruptionFrame, StartFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven from pipecat.services.tts_service import AudioContextTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language @@ -563,14 +561,22 @@ class CartesiaTTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - context_id = self.get_active_audio_context_id() - await super()._handle_interruption(frame, direction) + async def on_audio_context_interrupted(self, context_id: str): + """Cancel the active Cartesia context when the bot is interrupted.""" await self.stop_all_metrics() if context_id: cancel_msg = json.dumps({"context_id": context_id, "cancel": True}) await self._get_websocket().send(cancel_msg) + async def on_audio_context_completed(self, 
context_id: str): + """No-op: called after all audio of a Cartesia context has been played. + + No close message is needed: the server already considers the context + done once it has sent its ``done`` message, which is handled in + ``_process_messages``. + """ + pass + async def flush_audio(self): """Flush any pending audio and finalize the current context.""" context_id = self.get_active_audio_context_id() diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index 20d46481d..25e1aa5dd 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -666,14 +666,11 @@ class ElevenLabsTTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by closing the current context.""" - # Close the current context when interrupted without closing the websocket - context_id = self.get_active_audio_context_id() - await super()._handle_interruption(frame, direction) - + async def _close_context(self, context_id: str): + # ElevenLabs requires that Pipecat explicitly closes contexts to free + # server-side resources, both on interruption and on normal completion. if context_id and self._websocket: - logger.trace(f"Closing context {context_id} due to interruption") + logger.trace(f"{self}: Closing context {context_id}") try: # ElevenLabs requires that Pipecat manages the contexts and closes them # when they're not longer in use. 
Since an InterruptionFrame is pushed @@ -686,8 +683,21 @@ class ElevenLabsTTSService(AudioContextTTSService): ) except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) - self._partial_word = "" - self._partial_word_start_time = 0.0 + self._partial_word = "" + self._partial_word_start_time = 0.0 + + async def on_audio_context_interrupted(self, context_id: str): + """Close the ElevenLabs context when the bot is interrupted.""" + await self._close_context(context_id) + + async def on_audio_context_completed(self, context_id: str): + """Close the ElevenLabs context after all audio has been played. + + ElevenLabs does not send a server-side signal when a context is + exhausted, so Pipecat must explicitly close it with + ``close_context: True`` to free server-side resources. + """ + await self._close_context(context_id) async def _receive_messages(self): """Handle incoming WebSocket messages from ElevenLabs.""" diff --git a/src/pipecat/services/gradium/tts.py b/src/pipecat/services/gradium/tts.py index 703289706..ee6e6821e 100644 --- a/src/pipecat/services/gradium/tts.py +++ b/src/pipecat/services/gradium/tts.py @@ -17,13 +17,11 @@ from pipecat.frames.frames import ( EndFrame, ErrorFrame, Frame, - InterruptionFrame, StartFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven from pipecat.services.tts_service import AudioContextTTSService from pipecat.utils.tracing.service_decorators import traced_tts @@ -265,21 +263,24 @@ class GradiumTTSService(AudioContextTTSService): except Exception as e: logger.error(f"{self} exception: {e}") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by resetting context state. 
+ async def on_audio_context_interrupted(self, context_id: str): + """Called when an audio context is cancelled due to an interruption. - The parent AudioContextTTSService._handle_interruption() cancels the audio context - task and creates a new one. We reset _context_id so the next run_tts() creates a - fresh context. No websocket reconnection needed — audio from the old client_req_id - will be silently dropped since the audio context no longer exists. - - Args: - frame: The interruption frame. - direction: The direction of the frame. + No WebSocket message is needed — audio from the interrupted + ``client_req_id`` will be silently dropped by the base class once the + audio context no longer exists. """ - await super()._handle_interruption(frame, direction) await self.stop_all_metrics() + async def on_audio_context_completed(self, context_id: str): + """Called after an audio context has finished playing all of its audio. + + No close message is needed: Gradium signals completion with an + ``end_of_stream`` message (handled in ``_receive_messages``), after + which the server-side context is already closed. + """ + pass + async def _receive_messages(self): """Process incoming websocket messages, demultiplexing by client_req_id.""" # TODO(laurent): This should not be necessary as it should happen when diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index 2f35dc27c..22bdf22ff 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -681,28 +681,23 @@ class InworldTTSService(AudioContextTTSService): return word_times - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle an interruption from the Inworld WebSocket TTS service. - - Args: - frame: The interruption frame. - direction: The direction of the interruption. 
- """ - old_context_id = self.get_active_audio_context_id() - logger.trace(f"{self}: Handling interruption, old context: {old_context_id}") - - await super()._handle_interruption(frame, direction) - - if old_context_id and self._websocket: - logger.trace(f"{self}: Closing context {old_context_id} due to interruption") + async def _close_context(self, context_id: str): + if context_id and self._websocket: + logger.info(f"{self}: Closing context {context_id} due to interruption or completion") try: - await self._send_close_context(old_context_id) + await self._send_close_context(context_id) except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) - self._cumulative_time = 0.0 self._generation_end_time = 0.0 - logger.trace(f"{self}: Interruption handled, context reset to None") + + async def on_audio_context_interrupted(self, context_id: str): + """Callback invoked when an audio context has been interrupted.""" + await self._close_context(context_id) + + async def on_audio_context_completed(self, context_id: str): + """Callback invoked when an audio context has been completed.""" + await self._close_context(context_id) def _get_websocket(self): """Get the websocket for the Inworld WebSocket TTS service. 
diff --git a/src/pipecat/services/resembleai/tts.py b/src/pipecat/services/resembleai/tts.py index 026d29d3f..c2ac758a7 100644 --- a/src/pipecat/services/resembleai/tts.py +++ b/src/pipecat/services/resembleai/tts.py @@ -18,13 +18,11 @@ from pipecat.frames.frames import ( EndFrame, ErrorFrame, Frame, - InterruptionFrame, StartFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven from pipecat.services.tts_service import AudioContextTTSService from pipecat.utils.tracing.service_decorators import traced_tts @@ -247,16 +245,19 @@ class ResembleAITTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by stopping current synthesis. - - Args: - frame: The interruption frame. - direction: The direction of frame processing. - """ - await super()._handle_interruption(frame, direction) + async def on_audio_context_interrupted(self, context_id: str): + """Stop metrics when the bot is interrupted.""" await self.stop_all_metrics() + async def on_audio_context_completed(self, context_id: str): + """No-op: called after the Resemble AI context finishes playing. + + No close message is needed: Resemble AI signals completion with an + ``audio_end`` message (handled in ``_process_messages``), after which + the server-side context is already closed. 
+ """ + pass + async def flush_audio(self): """Flush any pending audio and finalize the current context.""" logger.trace(f"{self}: flushing audio") diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index 248c84008..83c2305d5 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -458,14 +458,25 @@ class RimeTTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by clearing current context.""" - context_id = self.get_active_audio_context_id() - await super()._handle_interruption(frame, direction) + async def _close_context(self, context_id: str): + """Clear the Rime speech queue and stop metrics.""" await self.stop_all_metrics() if context_id: await self._get_websocket().send(json.dumps(self._build_clear_msg())) + async def on_audio_context_interrupted(self, context_id: str): + """Clear the Rime speech queue and stop metrics when the bot is interrupted.""" + await self._close_context(context_id) + + async def on_audio_context_completed(self, context_id: str): + """Clear server-side state and stop metrics after the Rime context finishes playing. + + Rime does not send a server-side completion signal (e.g. ``done`` / ``end_of_stream`` / + ``audio_end``), so we explicitly send a ``clear`` message to clean up + any residual server-side state once all audio has been delivered. + """ + await self._close_context(context_id) + def _calculate_word_times(self, words: list, starts: list, ends: list) -> list: """Calculate word timing pairs with proper spacing and punctuation.