Merge pull request #2386 from pipecat-ai/filipi/parallel_pipeline

Only push the StartFrame when all parallel pipelines have processed it
This commit is contained in:
Filipi da Silva Fuchter
2025-08-07 14:20:30 -03:00
committed by GitHub
2 changed files with 29 additions and 3 deletions

View File

@@ -144,6 +144,8 @@ class ParallelPipeline(BasePipeline):
self._seen_ids = set()
self._endframe_counter: Dict[int, int] = {}
self._start_frame_counter: Dict[int, int] = {}
self._started = False
self._up_task = None
self._down_task = None
@@ -185,7 +187,7 @@ class ParallelPipeline(BasePipeline):
# We will add a source before the pipeline and a sink after.
source = ParallelPipelineSource(self._up_queue, self._parallel_push_frame)
sink = ParallelPipelineSink(self._down_queue, self._parallel_push_frame)
sink = ParallelPipelineSink(self._down_queue, self._pipeline_sink_push_frame)
self._sources.append(source)
self._sinks.append(sink)
@@ -218,7 +220,7 @@ class ParallelPipeline(BasePipeline):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start(frame)
self._start_frame_counter[frame.id] = len(self._pipelines)
elif isinstance(frame, EndFrame):
self._endframe_counter[frame.id] = len(self._pipelines)
elif isinstance(frame, CancelFrame):
@@ -297,6 +299,25 @@ class ParallelPipeline(BasePipeline):
self._seen_ids.add(frame.id)
await self.push_frame(frame, direction)
async def _pipeline_sink_push_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
# Decrement counter and check if all pipelines have processed the StartFrame
start_frame_counter = self._start_frame_counter.get(frame.id, 0)
if start_frame_counter > 0:
self._start_frame_counter[frame.id] -= 1
start_frame_counter = self._start_frame_counter[frame.id]
# Only push the StartFrame when all pipelines have processed it
if start_frame_counter == 0:
self._started = True
await self._start(frame)
await self._parallel_push_frame(frame, direction)
else:
if self._started:
await self._parallel_push_frame(frame, direction)
else:
await self._down_queue.put(frame)
async def _process_up_queue(self):
"""Process upstream frames from all parallel branches."""
while True:

View File

@@ -256,6 +256,7 @@ class FrameProcessor(BaseObject):
self.__should_block_frames = False
self.__process_event = None
self.__process_frame_task: Optional[asyncio.Task] = None
self.__process_queue = None
@property
def id(self) -> int:
@@ -782,8 +783,12 @@ class FrameProcessor(BaseObject):
if isinstance(frame, SystemFrame):
await self.__process_frame(frame, direction, callback)
else:
elif self.__process_queue:
await self.__process_queue.put((frame, direction, callback))
else:
raise RuntimeError(
f"{self}: __process_queue is None when processing frame {frame.name}"
)
async def __process_frame_task_handler(self):
"""Handle non-system frames from the process queue."""