Refactored all AudioContextTTSService based providers to override the new callbacks instead of _handle_interruption(), making provider-specific cleanup cleaner and more explicit
This commit is contained in:
@@ -10,7 +10,7 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, ClassVar, Dict, Mapping, Optional
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -392,18 +391,29 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
async def _close_context(self, context_id: str):
|
||||
# Async AI requires explicit context closure to free server-side resources,
|
||||
# both on interruption and on normal completion.
|
||||
if context_id and self._websocket:
|
||||
try:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": context_id, "close_context": True, "transcript": ""})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
logger.error(f"{self}: Error closing context {context_id}: {e}")
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Close the Async AI context when the bot is interrupted."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Close the Async AI context after all audio has been played.
|
||||
|
||||
Async AI does not send a server-side signal when a context is
|
||||
exhausted, so Pipecat must explicitly close it with
|
||||
``close_context: True`` to free server-side resources.
|
||||
"""
|
||||
await self._close_context(context_id)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
|
||||
@@ -11,7 +11,7 @@ import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, ClassVar, Dict, List, Literal, Mapping, Optional
|
||||
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -21,13 +21,11 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -563,14 +561,22 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Cancel the active Cartesia context when the bot is interrupted."""
|
||||
await self.stop_all_metrics()
|
||||
if context_id:
|
||||
cancel_msg = json.dumps({"context_id": context_id, "cancel": True})
|
||||
await self._get_websocket().send(cancel_msg)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Close the Cartesia context after all audio has been played.
|
||||
|
||||
No close message is needed: the server already considers the context
|
||||
done once it has sent its ``done`` message, which is handled in
|
||||
``_process_messages``.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
|
||||
@@ -666,14 +666,11 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
async def _close_context(self, context_id: str):
|
||||
# ElevenLabs requires that Pipecat explicitly closes contexts to free
|
||||
# server-side resources, both on interruption and on normal completion.
|
||||
if context_id and self._websocket:
|
||||
logger.trace(f"Closing context {context_id} due to interruption")
|
||||
logger.trace(f"{self}: Closing context {context_id}")
|
||||
try:
|
||||
# ElevenLabs requires that Pipecat manages the contexts and closes them
|
||||
# when they're not longer in use. Since an InterruptionFrame is pushed
|
||||
@@ -686,8 +683,21 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Close the ElevenLabs context when the bot is interrupted."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Close the ElevenLabs context after all audio has been played.
|
||||
|
||||
ElevenLabs does not send a server-side signal when a context is
|
||||
exhausted, so Pipecat must explicitly close it with
|
||||
``close_context: True`` to free server-side resources.
|
||||
"""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Handle incoming WebSocket messages from ElevenLabs."""
|
||||
|
||||
@@ -17,13 +17,11 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
@@ -265,21 +263,24 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by resetting context state.
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Called when an audio context is cancelled due to an interruption.
|
||||
|
||||
The parent AudioContextTTSService._handle_interruption() cancels the audio context
|
||||
task and creates a new one. We reset _context_id so the next run_tts() creates a
|
||||
fresh context. No websocket reconnection needed — audio from the old client_req_id
|
||||
will be silently dropped since the audio context no longer exists.
|
||||
|
||||
Args:
|
||||
frame: The interruption frame.
|
||||
direction: The direction of the frame.
|
||||
No WebSocket message is needed — audio from the interrupted
|
||||
``client_req_id`` will be silently dropped by the base class once the
|
||||
audio context no longer exists.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Called after an audio context has finished playing all of its audio.
|
||||
|
||||
No close message is needed: Gradium signals completion with an
|
||||
``end_of_stream`` message (handled in ``_receive_messages``), after
|
||||
which the server-side context is already closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages, demultiplexing by client_req_id."""
|
||||
# TODO(laurent): This should not be necessary as it should happen when
|
||||
|
||||
@@ -681,28 +681,23 @@ class InworldTTSService(AudioContextTTSService):
|
||||
|
||||
return word_times
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle an interruption from the Inworld WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The interruption frame.
|
||||
direction: The direction of the interruption.
|
||||
"""
|
||||
old_context_id = self.get_active_audio_context_id()
|
||||
logger.trace(f"{self}: Handling interruption, old context: {old_context_id}")
|
||||
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
if old_context_id and self._websocket:
|
||||
logger.trace(f"{self}: Closing context {old_context_id} due to interruption")
|
||||
async def _close_context(self, context_id: str):
|
||||
if context_id and self._websocket:
|
||||
logger.info(f"{self}: Closing context {context_id} due to interruption or completion")
|
||||
try:
|
||||
await self._send_close_context(old_context_id)
|
||||
await self._send_close_context(context_id)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
logger.trace(f"{self}: Interruption handled, context reset to None")
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Callback invoked when an audio context has been interrupted."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Callback invoked when an audio context has been completed."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get the websocket for the Inworld WebSocket TTS service.
|
||||
|
||||
@@ -18,13 +18,11 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
@@ -247,16 +245,19 @@ class ResembleAITTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by stopping current synthesis.
|
||||
|
||||
Args:
|
||||
frame: The interruption frame.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Stop metrics when the bot is interrupted."""
|
||||
await self.stop_all_metrics()
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Stop metrics after the Resemble AI context finishes playing.
|
||||
|
||||
No close message is needed: Resemble AI signals completion with an
|
||||
``audio_end`` message (handled in ``_process_messages``), after which
|
||||
the server-side context is already closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
|
||||
@@ -458,14 +458,25 @@ class RimeTTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by clearing current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
async def _close_context(self, context_id: str):
|
||||
"""Clear the Rime speech queue and stop metrics."""
|
||||
await self.stop_all_metrics()
|
||||
if context_id:
|
||||
await self._get_websocket().send(json.dumps(self._build_clear_msg()))
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Clear the Rime speech queue and stop metrics when the bot is interrupted."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Clear server-side state and stop metrics after the Rime context finishes playing.
|
||||
|
||||
Rime does not send a server-side completion signal (e.g. ``done`` / ``end_of_stream`` /
|
||||
``audio_end``), so we explicitly send a ``clear`` message to clean up
|
||||
any residual server-side state once all audio has been delivered.
|
||||
"""
|
||||
await self._close_context(context_id)
|
||||
|
||||
def _calculate_word_times(self, words: list, starts: list, ends: list) -> list:
|
||||
"""Calculate word timing pairs with proper spacing and punctuation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user