Add tests for InterruptionFrame completion event
Add tests for the event-based interruption completion: complete() sets the event, complete() is safe without an event, the event fires at the pipeline sink, and a warning is logged when the frame is blocked. Also remove the unconditional await after the timeout so the function returns instead of hanging when complete() is never called.
This commit is contained in:
@@ -9,6 +9,8 @@ import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
@@ -22,7 +24,11 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import (
|
||||
INTERRUPTION_COMPLETION_TIMEOUT,
|
||||
FrameDirection,
|
||||
FrameProcessor,
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
@@ -449,6 +455,109 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
stop_frames = [f for f in received_frames if isinstance(f, StopFrame)]
|
||||
self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption")
|
||||
|
||||
async def test_interruption_frame_complete_sets_event(self):
|
||||
"""Test that InterruptionFrame.complete() sets the event."""
|
||||
event = asyncio.Event()
|
||||
frame = InterruptionFrame(event=event)
|
||||
self.assertFalse(event.is_set())
|
||||
frame.complete()
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
async def test_interruption_frame_complete_without_event(self):
|
||||
"""Test that InterruptionFrame.complete() is safe without an event."""
|
||||
frame = InterruptionFrame()
|
||||
frame.complete() # Should not raise
|
||||
|
||||
async def test_interruption_event_set_at_pipeline_sink(self):
|
||||
"""Test that the event from push_interruption_task_frame_and_wait()
|
||||
is set when the InterruptionFrame reaches the pipeline sink."""
|
||||
event_was_set = False
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
nonlocal event_was_set
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
event_was_set = True
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([InterruptOnTextProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
TextFrame(text="trigger"),
|
||||
]
|
||||
expected_down_frames = [
|
||||
InterruptionFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes")
|
||||
|
||||
async def test_interruption_completion_timeout_warning(self):
|
||||
"""Test that a warning is logged when an InterruptionFrame is blocked
|
||||
and never reaches the pipeline sink."""
|
||||
warnings = []
|
||||
handler_id = logger.add(
|
||||
lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
class BlockInterruptionProcessor(FrameProcessor):
|
||||
"""Blocks InterruptionFrames, completing them after a delay."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
# Complete after the timeout so the warning fires
|
||||
# but the test doesn't hang.
|
||||
async def delayed_complete():
|
||||
await asyncio.sleep(INTERRUPTION_COMPLETION_TIMEOUT + 1.0)
|
||||
frame.complete()
|
||||
|
||||
asyncio.create_task(delayed_complete())
|
||||
return
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
TextFrame(text="trigger"),
|
||||
]
|
||||
expected_down_frames = [
|
||||
OutputTransportMessageUrgentFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
finally:
|
||||
logger.remove(handler_id)
|
||||
|
||||
self.assertTrue(
|
||||
any("InterruptionFrame has not completed" in w for w in warnings),
|
||||
"Expected a timeout warning about InterruptionFrame not completing",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user