diff --git a/changelog/3668.fixed.md b/changelog/3668.fixed.md new file mode 100644 index 000000000..6885a7591 --- /dev/null +++ b/changelog/3668.fixed.md @@ -0,0 +1 @@ +- Fixed `ParallelPipeline` allowing frames pushed by internal processors to escape during lifecycle frame (`StartFrame`/`EndFrame`/`CancelFrame`) synchronization. These frames are now buffered and flushed after all branches complete. diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index 81beeead8..88ea04638 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -52,6 +52,8 @@ class ParallelPipeline(BasePipeline): self._seen_ids = set() self._frame_counter: Dict[int, int] = {} + self._synchronizing: bool = False + self._buffered_frames: list[tuple[Frame, FrameDirection]] = [] logger.debug(f"Creating {self} pipelines") @@ -143,6 +145,7 @@ class ParallelPipeline(BasePipeline): # Parallel pipeline synchronized frames. if isinstance(frame, (StartFrame, EndFrame, CancelFrame)): self._frame_counter[frame.id] = len(self._pipelines) + self._synchronizing = True await self.pause_processing_system_frames() await self.pause_processing_frames() @@ -151,10 +154,18 @@ class ParallelPipeline(BasePipeline): await p.queue_frame(frame, direction) async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection): - """Push frames while avoiding duplicates using frame ID tracking.""" + """Push frames while avoiding duplicates using frame ID tracking. + + During lifecycle frame synchronization, non-lifecycle frames are buffered + to prevent them from escaping the parallel pipeline before all branches + have finished processing the lifecycle frame. + """ if frame.id not in self._seen_ids: self._seen_ids.add(frame.id) - await self.push_frame(frame, direction) + if self._synchronizing: + self._buffered_frames.append((frame, direction)) + else: + await self.push_frame(frame, direction) async def _pipeline_sink_push_frame(self, frame: Frame, direction: FrameDirection): # Parallel pipeline synchronized frames. @@ -167,8 +178,17 @@ class ParallelPipeline(BasePipeline): # Only push the frame when all pipelines have processed it. if frame_counter == 0: + self._synchronizing = False await self._parallel_push_frame(frame, direction) + await self._flush_buffered_frames() await self.resume_processing_system_frames() await self.resume_processing_frames() else: await self._parallel_push_frame(frame, direction) + + async def _flush_buffered_frames(self): + """Flush frames that were buffered during lifecycle frame synchronization.""" + frames = self._buffered_frames + self._buffered_frames = [] + for frame, direction in frames: + await self.push_frame(frame, direction) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dda9e583d..71121a3fc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -96,6 +96,34 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): expected_down_frames=expected_down_frames, ) + async def test_parallel_internal_frames_buffered_during_start(self): + """Frames pushed by internal processors during StartFrame processing + should be buffered and only released after StartFrame synchronization + completes.""" + + class EmitOnStartProcessor(FrameProcessor): + """Pushes a TextFrame when it receives a StartFrame.""" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + if isinstance(frame, StartFrame): + await self.push_frame(TextFrame(text="from start")) + + pipeline = ParallelPipeline([EmitOnStartProcessor()], [IdentityFilter()]) + + frames_to_send = [TextFrame(text="Hello!")] + + # StartFrame should come first, then the TextFrame emitted during + # StartFrame processing, then the regular TextFrame. + expected_down_frames = [StartFrame, TextFrame, TextFrame] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ignore_start=False, + ) + class TestPipelineTask(unittest.IsolatedAsyncioTestCase): async def test_task_single(self):