Add tests for InterruptionFrame completion event

Add tests for the event-based interruption completion: complete() sets
the event, complete() is safe without an event, the event fires at
the pipeline sink, and a warning is logged when the frame is blocked.

Also remove the unconditional await after the timeout so the function
returns instead of hanging when complete() is never called.
This commit is contained in:
Aleix Conchillo Flaqué
2026-02-05 22:54:36 -08:00
parent 2345090b10
commit a352b2d7a0

View File

@@ -9,6 +9,8 @@ import unittest
from dataclasses import dataclass, field
from typing import List
from loguru import logger
from pipecat.frames.frames import (
DataFrame,
EndFrame,
@@ -22,7 +24,11 @@ from pipecat.frames.frames import (
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.filters.identity_filter import IdentityFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import (
INTERRUPTION_COMPLETION_TIMEOUT,
FrameDirection,
FrameProcessor,
)
from pipecat.tests.utils import SleepFrame, run_test
@@ -449,6 +455,109 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
stop_frames = [f for f in received_frames if isinstance(f, StopFrame)]
self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption")
async def test_interruption_frame_complete_sets_event(self):
"""Test that InterruptionFrame.complete() sets the event."""
event = asyncio.Event()
frame = InterruptionFrame(event=event)
self.assertFalse(event.is_set())
frame.complete()
self.assertTrue(event.is_set())
async def test_interruption_frame_complete_without_event(self):
"""Test that InterruptionFrame.complete() is safe without an event."""
frame = InterruptionFrame()
frame.complete() # Should not raise
async def test_interruption_event_set_at_pipeline_sink(self):
"""Test that the event from push_interruption_task_frame_and_wait()
is set when the InterruptionFrame reaches the pipeline sink."""
event_was_set = False
class InterruptOnTextProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
nonlocal event_was_set
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_interruption_task_frame_and_wait()
event_was_set = True
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
else:
await self.push_frame(frame, direction)
pipeline = Pipeline([InterruptOnTextProcessor()])
frames_to_send = [
TextFrame(text="trigger"),
]
expected_down_frames = [
InterruptionFrame,
OutputTransportMessageUrgentFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes")
async def test_interruption_completion_timeout_warning(self):
"""Test that a warning is logged when an InterruptionFrame is blocked
and never reaches the pipeline sink."""
warnings = []
handler_id = logger.add(
lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}"
)
try:
class BlockInterruptionProcessor(FrameProcessor):
"""Blocks InterruptionFrames, completing them after a delay."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
# Complete after the timeout so the warning fires
# but the test doesn't hang.
async def delayed_complete():
await asyncio.sleep(INTERRUPTION_COMPLETION_TIMEOUT + 1.0)
frame.complete()
asyncio.create_task(delayed_complete())
return
await self.push_frame(frame, direction)
class InterruptOnTextProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_interruption_task_frame_and_wait()
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
else:
await self.push_frame(frame, direction)
pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()])
frames_to_send = [
TextFrame(text="trigger"),
]
expected_down_frames = [
OutputTransportMessageUrgentFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
finally:
logger.remove(handler_id)
self.assertTrue(
any("InterruptionFrame has not completed" in w for w in warnings),
"Expected a timeout warning about InterruptionFrame not completing",
)
if __name__ == "__main__":
unittest.main()