Merge branch 'main' into filipi/refactoring_elevenlabs
This commit is contained in:
1
changelog/4434.fixed.md
Normal file
1
changelog/4434.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed interruptions being delayed when a slow interruptible frame (one that is not an `UninterruptibleFrame`) was being processed while an uninterruptible frame was waiting in the queue. The bot would stall until the slow frame finished instead of cancelling it immediately on interruption.
|
||||
1
changelog/4435.fixed.md
Normal file
1
changelog/4435.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `TTSService` dropping uninterruptible frames (e.g. `FunctionCallResultFrame`) from its internal serialization queue when an interruption occurs. Previously, the queue was recreated on every interruption, silently discarding any queued frames. The queue is now reset instead of recreated, preserving uninterruptible frames so they are always delivered downstream.
|
||||
@@ -877,14 +877,19 @@ class FrameProcessor(BaseObject):
|
||||
current_is_uninterruptible = isinstance(
|
||||
self.__process_current_frame, UninterruptibleFrame
|
||||
)
|
||||
if current_is_uninterruptible or self.__process_queue.has_uninterruptible:
|
||||
# We don't want to cancel an UninterruptibleFrame (either the
|
||||
# one currently being processed or one waiting in the queue),
|
||||
# so we simply cleanup the queue keeping only
|
||||
# UninterruptibleFrames.
|
||||
if current_is_uninterruptible:
|
||||
# The frame currently being processed is uninterruptible, so we
|
||||
# must not cancel it. Just flush non-uninterruptible frames from
|
||||
# the queue; any uninterruptible ones will be kept and processed
|
||||
# after the current frame finishes.
|
||||
self.__reset_process_queue()
|
||||
else:
|
||||
# Cancel and re-create the process task.
|
||||
# Cancel and re-create the process task. Previously this branch
|
||||
# was skipped when the queue contained an uninterruptible frame,
|
||||
# which caused slow non-uninterruptible frames to block
|
||||
# interruptions. Uninterruptible queued frames are safe here
|
||||
# because __create_process_task calls __reset_process_queue
|
||||
# internally, which always preserves them.
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
|
||||
@@ -50,6 +50,7 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import TTSSettings, is_given
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.frame_queue import FrameQueue
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
@@ -325,6 +326,20 @@ class TTSService(AIService):
|
||||
self._audio_contexts: dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task: asyncio.Task | None = None
|
||||
|
||||
# Single FIFO queue that serializes everything the TTS service emits downstream.
|
||||
# Items can be:
|
||||
# str – an audio context ID: process the per-context audio queue in full before
|
||||
# moving on (see _handle_audio_context).
|
||||
# Frame – a non-system downstream frame (e.g. AggregatedTextFrame, FooFrame) that
|
||||
# must be emitted in-order relative to surrounding audio contexts.
|
||||
# None – shutdown sentinel (sent by stop()).
|
||||
# Created once here so it survives interruptions: on interruption we call reset()
|
||||
# which drops non-UninterruptibleFrame items while keeping uninterruptible ones
|
||||
# (e.g. FunctionCallResultFrame) that must not be lost mid-flight.
|
||||
self._serialization_queue: FrameQueue = FrameQueue(
|
||||
frame_getter=lambda item: item if isinstance(item, Frame) else None
|
||||
)
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
@@ -875,6 +890,9 @@ class TTSService(AIService):
|
||||
await self.reset_word_timestamps()
|
||||
|
||||
await self._stop_audio_context_task()
|
||||
# Drops non-UninterruptibleFrame items while keeping uninterruptible ones
|
||||
# (e.g. FunctionCallResultFrame) that must not be lost mid-flight.
|
||||
self._serialization_queue.reset()
|
||||
audio_contexts = self.get_audio_contexts()
|
||||
if audio_contexts:
|
||||
for ctx_id in audio_contexts:
|
||||
@@ -1298,14 +1316,6 @@ class TTSService(AIService):
|
||||
|
||||
def _create_audio_context_task(self):
|
||||
if not self._audio_context_task:
|
||||
# Single FIFO queue that serializes everything the TTS service emits downstream.
|
||||
# Items can be:
|
||||
# str – an audio context ID: process the per-context audio queue in full before
|
||||
# moving on (see _handle_audio_context).
|
||||
# Frame – a non-system downstream frame (e.g. AggregatedTextFrame, FooFrame) that
|
||||
# must be emitted in-order relative to surrounding audio contexts.
|
||||
# None – shutdown sentinel (sent by stop()).
|
||||
self._serialization_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._audio_contexts: dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = self.create_task(self._audio_context_task_handler())
|
||||
|
||||
|
||||
@@ -852,5 +852,108 @@ async def test_no_deadlock_on_interrupt_before_audio_with_uninterruptible():
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization queue interruption tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockBlockingWebSocketTTSService(TTSService):
    """TTS test double that opens an audio context yet never produces audio.

    Because no audio ever arrives, the audio context consumer stays blocked on
    the per-context queue, so frames sent afterwards pile up in the
    serialization queue. The service is built with
    ``pause_frame_processing=False`` so that frames following a TTSSpeakFrame
    go straight into the serialization queue instead of stalling in the
    FrameProcessor.
    """

    def __init__(self, **kwargs):
        """Build the service with the fixed options these tests rely on.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``TTSService.__init__``.
        """
        super().__init__(
            sample_rate=_SAMPLE_RATE,
            pause_frame_processing=False,
            push_text_frames=False,
            push_start_frame=True,
            **kwargs,
        )

    def can_generate_metrics(self) -> bool:
        """Report that this mock never generates metrics."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Yield no frames at all.

        Args:
            text: Text to synthesize (ignored).
            context_id: Audio context identifier (ignored).
        """
        # The unreachable ``yield`` keeps this an async generator while
        # guaranteeing that it finishes without emitting anything.
        return
        yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_serialization_queue_drops_regular_frames_on_interruption():
    """An interruption discards regular frames held in the serialization queue.

    The audio context consumer is blocked because no audio is delivered, so the
    FooFrame sent next waits in the serialization queue. The InterruptionFrame
    resets that queue, and the FooFrame must never show up downstream.
    """
    service = MockBlockingWebSocketTTSService()

    # Run the scripted sequence, bounded so a deadlock fails instead of hanging.
    frames_received = await asyncio.wait_for(
        run_test(
            service,
            frames_to_send=[
                TTSSpeakFrame(text="hello", append_to_context=False),
                SleepFrame(sleep=0.05),  # give the audio context task time to block
                FooFrame(label="will_be_dropped"),
                SleepFrame(sleep=0.05),  # give FooFrame time to reach the serialization queue
                InterruptionFrame(),
                SleepFrame(sleep=0.1),  # give interruption handling time to finish
            ],
        ),
        timeout=5.0,
    )

    downstream = frames_received[0]
    foo_frames = [frame for frame in downstream if isinstance(frame, FooFrame)]
    assert not foo_frames, (
        f"FooFrame should be dropped on interruption, but {len(foo_frames)} arrived downstream"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_serialization_queue_preserves_uninterruptible_frames_on_interruption():
    """An interruption keeps uninterruptible frames held in the serialization queue.

    With the audio context consumer blocked, a regular FooFrame and an
    UninterruptibleMarkerFrame both wait in the serialization queue. When the
    InterruptionFrame arrives, reset() discards the FooFrame and retains the
    UninterruptibleMarkerFrame, which the new audio context task then pushes
    downstream.
    """
    service = MockBlockingWebSocketTTSService()

    # Run the scripted sequence, bounded so a deadlock fails instead of hanging.
    frames_received = await asyncio.wait_for(
        run_test(
            service,
            frames_to_send=[
                TTSSpeakFrame(text="hello", append_to_context=False),
                SleepFrame(sleep=0.05),  # give the audio context task time to block
                FooFrame(label="will_be_dropped"),
                UninterruptibleMarkerFrame(label="must_survive"),
                SleepFrame(sleep=0.05),  # give both frames time to reach the serialization queue
                InterruptionFrame(),
                SleepFrame(sleep=0.1),  # give interruption handling and the new task time to run
            ],
        ),
        timeout=5.0,
    )

    downstream = frames_received[0]

    # Split the downstream frames by type in one pass, preserving arrival order.
    foo_frames = []
    uninterruptible_frames = []
    for frame in downstream:
        if isinstance(frame, FooFrame):
            foo_frames.append(frame)
        if isinstance(frame, UninterruptibleMarkerFrame):
            uninterruptible_frames.append(frame)

    assert not foo_frames, (
        f"FooFrame should be dropped on interruption, but {len(foo_frames)} arrived downstream"
    )

    assert len(uninterruptible_frames) == 1, (
        f"UninterruptibleMarkerFrame must survive interruption, "
        f"but {len(uninterruptible_frames)} arrived downstream"
    )
    survivor = uninterruptible_frames[0]
    assert survivor.label == "must_survive"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user