Buffer internal frames during ParallelPipeline lifecycle synchronization
Processors inside parallel sub-pipelines can push frames during StartFrame/EndFrame/CancelFrame processing. Previously these frames could escape the ParallelPipeline before all branches finished processing the lifecycle frame. Now they are buffered and flushed after synchronization completes.
This commit is contained in:
1
changelog/3668.fixed.md
Normal file
1
changelog/3668.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `ParallelPipeline` allowing frames pushed by internal processors to escape during lifecycle frame (`StartFrame`/`EndFrame`/`CancelFrame`) synchronization. These frames are now buffered and flushed after all branches complete.
|
||||
@@ -52,6 +52,8 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
self._seen_ids = set()
|
||||
self._frame_counter: Dict[int, int] = {}
|
||||
self._synchronizing: bool = False
|
||||
self._buffered_frames: list[tuple[Frame, FrameDirection]] = []
|
||||
|
||||
logger.debug(f"Creating {self} pipelines")
|
||||
|
||||
@@ -143,6 +145,7 @@ class ParallelPipeline(BasePipeline):
|
||||
# Parallel pipeline synchronized frames.
|
||||
if isinstance(frame, (StartFrame, EndFrame, CancelFrame)):
|
||||
self._frame_counter[frame.id] = len(self._pipelines)
|
||||
self._synchronizing = True
|
||||
await self.pause_processing_system_frames()
|
||||
await self.pause_processing_frames()
|
||||
|
||||
@@ -151,10 +154,18 @@ class ParallelPipeline(BasePipeline):
|
||||
await p.queue_frame(frame, direction)
|
||||
|
||||
async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Push frames while avoiding duplicates using frame ID tracking."""
|
||||
"""Push frames while avoiding duplicates using frame ID tracking.
|
||||
|
||||
During lifecycle frame synchronization, non-lifecycle frames are buffered
|
||||
to prevent them from escaping the parallel pipeline before all branches
|
||||
have finished processing the lifecycle frame.
|
||||
"""
|
||||
if frame.id not in self._seen_ids:
|
||||
self._seen_ids.add(frame.id)
|
||||
await self.push_frame(frame, direction)
|
||||
if self._synchronizing:
|
||||
self._buffered_frames.append((frame, direction))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _pipeline_sink_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
# Parallel pipeline synchronized frames.
|
||||
@@ -167,8 +178,17 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
# Only push the frame when all pipelines have processed it.
|
||||
if frame_counter == 0:
|
||||
self._synchronizing = False
|
||||
await self._parallel_push_frame(frame, direction)
|
||||
await self._flush_buffered_frames()
|
||||
await self.resume_processing_system_frames()
|
||||
await self.resume_processing_frames()
|
||||
else:
|
||||
await self._parallel_push_frame(frame, direction)
|
||||
|
||||
async def _flush_buffered_frames(self):
|
||||
"""Flush frames that were buffered during lifecycle frame synchronization."""
|
||||
frames = self._buffered_frames
|
||||
self._buffered_frames = []
|
||||
for frame, direction in frames:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -96,6 +96,34 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_parallel_internal_frames_buffered_during_start(self):
|
||||
"""Frames pushed by internal processors during StartFrame processing
|
||||
should be buffered and only released after StartFrame synchronization
|
||||
completes."""
|
||||
|
||||
class EmitOnStartProcessor(FrameProcessor):
|
||||
"""Pushes a TextFrame when it receives a StartFrame."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.push_frame(TextFrame(text="from start"))
|
||||
|
||||
pipeline = ParallelPipeline([EmitOnStartProcessor()], [IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello!")]
|
||||
|
||||
# StartFrame should come first, then the TextFrame emitted during
|
||||
# StartFrame processing, then the regular TextFrame.
|
||||
expected_down_frames = [StartFrame, TextFrame, TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
ignore_start=False,
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_task_single(self):
|
||||
|
||||
Reference in New Issue
Block a user