Merge pull request #2386 from pipecat-ai/filipi/parallel_pipeline
Only push the StartFrame when all parallel pipelines have processed it
This commit is contained in:
@@ -144,6 +144,8 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
self._seen_ids = set()
|
||||
self._endframe_counter: Dict[int, int] = {}
|
||||
self._start_frame_counter: Dict[int, int] = {}
|
||||
self._started = False
|
||||
|
||||
self._up_task = None
|
||||
self._down_task = None
|
||||
@@ -185,7 +187,7 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
# We will add a source before the pipeline and a sink after.
|
||||
source = ParallelPipelineSource(self._up_queue, self._parallel_push_frame)
|
||||
sink = ParallelPipelineSink(self._down_queue, self._parallel_push_frame)
|
||||
sink = ParallelPipelineSink(self._down_queue, self._pipeline_sink_push_frame)
|
||||
self._sources.append(source)
|
||||
self._sinks.append(sink)
|
||||
|
||||
@@ -218,7 +220,7 @@ class ParallelPipeline(BasePipeline):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
self._start_frame_counter[frame.id] = len(self._pipelines)
|
||||
elif isinstance(frame, EndFrame):
|
||||
self._endframe_counter[frame.id] = len(self._pipelines)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
@@ -297,6 +299,25 @@ class ParallelPipeline(BasePipeline):
|
||||
self._seen_ids.add(frame.id)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _pipeline_sink_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, StartFrame):
|
||||
# Decrement counter and check if all pipelines have processed the StartFrame
|
||||
start_frame_counter = self._start_frame_counter.get(frame.id, 0)
|
||||
if start_frame_counter > 0:
|
||||
self._start_frame_counter[frame.id] -= 1
|
||||
start_frame_counter = self._start_frame_counter[frame.id]
|
||||
|
||||
# Only push the StartFrame when all pipelines have processed it
|
||||
if start_frame_counter == 0:
|
||||
self._started = True
|
||||
await self._start(frame)
|
||||
await self._parallel_push_frame(frame, direction)
|
||||
else:
|
||||
if self._started:
|
||||
await self._parallel_push_frame(frame, direction)
|
||||
else:
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
async def _process_up_queue(self):
|
||||
"""Process upstream frames from all parallel branches."""
|
||||
while True:
|
||||
|
||||
@@ -256,6 +256,7 @@ class FrameProcessor(BaseObject):
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_queue = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -782,8 +783,12 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
elif self.__process_queue:
|
||||
await self.__process_queue.put((frame, direction, callback))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{self}: __process_queue is None when processing frame {frame.name}"
|
||||
)
|
||||
|
||||
async def __process_frame_task_handler(self):
|
||||
"""Handle non-system frames from the process queue."""
|
||||
|
||||
Reference in New Issue
Block a user