diff --git a/tests/test_frame_processor.py b/tests/test_frame_processor.py index 3a47520b3..2ce4b7880 100644 --- a/tests/test_frame_processor.py +++ b/tests/test_frame_processor.py @@ -9,6 +9,8 @@ import unittest from dataclasses import dataclass, field from typing import List +from loguru import logger + from pipecat.frames.frames import ( DataFrame, EndFrame, @@ -22,7 +24,11 @@ from pipecat.frames.frames import ( ) from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.filters.identity_filter import IdentityFilter -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import ( + INTERRUPTION_COMPLETION_TIMEOUT, + FrameDirection, + FrameProcessor, +) from pipecat.tests.utils import SleepFrame, run_test @@ -449,6 +455,109 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase): stop_frames = [f for f in received_frames if isinstance(f, StopFrame)] self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption") + async def test_interruption_frame_complete_sets_event(self): + """Test that InterruptionFrame.complete() sets the event.""" + event = asyncio.Event() + frame = InterruptionFrame(event=event) + self.assertFalse(event.is_set()) + frame.complete() + self.assertTrue(event.is_set()) + + async def test_interruption_frame_complete_without_event(self): + """Test that InterruptionFrame.complete() is safe without an event.""" + frame = InterruptionFrame() + frame.complete() # Should not raise + + async def test_interruption_event_set_at_pipeline_sink(self): + """Test that the event from push_interruption_task_frame_and_wait() + is set when the InterruptionFrame reaches the pipeline sink.""" + event_was_set = False + + class InterruptOnTextProcessor(FrameProcessor): + async def process_frame(self, frame: Frame, direction: FrameDirection): + nonlocal event_was_set + + await super().process_frame(frame, direction) + if isinstance(frame, TextFrame): + await self.push_interruption_task_frame_and_wait() + + event_was_set = True + await self.push_frame(OutputTransportMessageUrgentFrame(message="done")) + else: + await self.push_frame(frame, direction) + + pipeline = Pipeline([InterruptOnTextProcessor()]) + + frames_to_send = [ + TextFrame(text="trigger"), + ] + expected_down_frames = [ + InterruptionFrame, + OutputTransportMessageUrgentFrame, + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes") + + async def test_interruption_completion_timeout_warning(self): + """Test that a warning is logged when an InterruptionFrame is blocked + and never reaches the pipeline sink.""" + warnings = [] + handler_id = logger.add( + lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}" + ) + + try: + + class BlockInterruptionProcessor(FrameProcessor): + """Blocks InterruptionFrames, completing them after a delay.""" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, InterruptionFrame): + # Complete after the timeout so the warning fires + # but the test doesn't hang. + async def delayed_complete(): + await asyncio.sleep(INTERRUPTION_COMPLETION_TIMEOUT + 1.0) + frame.complete() + + asyncio.create_task(delayed_complete()) + return + await self.push_frame(frame, direction) + + class InterruptOnTextProcessor(FrameProcessor): + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, TextFrame): + await self.push_interruption_task_frame_and_wait() + await self.push_frame(OutputTransportMessageUrgentFrame(message="done")) + else: + await self.push_frame(frame, direction) + + pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()]) + + frames_to_send = [ + TextFrame(text="trigger"), + ] + expected_down_frames = [ + OutputTransportMessageUrgentFrame, + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + finally: + logger.remove(handler_id) + + self.assertTrue( + any("InterruptionFrame has not completed" in w for w in warnings), + "Expected a timeout warning about InterruptionFrame not completing", + ) + if __name__ == "__main__": unittest.main()