Add optional direction parameter to PipelineTask.queue_frame() and queue_frames()
Allow pushing frames upstream through the pipeline by passing FrameDirection.UPSTREAM. Downstream frames use the existing push queue, while upstream frames are pushed directly from the pipeline sink.
This commit is contained in:
@@ -389,12 +389,12 @@ class PipelineTask(BasePipelineTask):
|
||||
# source allows us to receive and react to upstream frames, and the sink
|
||||
# allows us to receive and react to downstream frames.
|
||||
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
|
||||
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
self._sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
# Only prepend the RTVIProcessor if we created it ourselves. When the
|
||||
# user already placed it inside their pipeline we must not insert it
|
||||
# again or it will appear twice in the frame chain.
|
||||
processors = [self._rtvi, pipeline] if prepend_rtvi else [pipeline]
|
||||
self._pipeline = Pipeline(processors, source=source, sink=sink)
|
||||
self._pipeline = Pipeline(processors, source=source, sink=self._sink)
|
||||
|
||||
# The task observer acts as a proxy to the provided observers. This way,
|
||||
# we only need to pass a single observer (using the StartFrame) which
|
||||
@@ -625,26 +625,43 @@ class PipelineTask(BasePipelineTask):
|
||||
self._finished = True
|
||||
logger.debug(f"Pipeline task {self} has finished")
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
"""Queue a single frame to be pushed down the pipeline.
|
||||
async def queue_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Queue a single frame to be pushed through the pipeline.
|
||||
|
||||
Downstream frames are pushed from the beginning of the pipeline.
|
||||
Upstream frames are pushed from the end of the pipeline.
|
||||
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
direction: The direction to push the frame. Defaults to downstream.
|
||||
"""
|
||||
await self._push_queue.put(frame)
|
||||
if direction == FrameDirection.DOWNSTREAM:
|
||||
await self._push_queue.put(frame)
|
||||
else:
|
||||
await self._sink.queue_frame(frame, direction)
|
||||
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
"""Queues multiple frames to be pushed down the pipeline.
|
||||
async def queue_frames(
|
||||
self,
|
||||
frames: Iterable[Frame] | AsyncIterable[Frame],
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
):
|
||||
"""Queue multiple frames to be pushed through the pipeline.
|
||||
|
||||
Downstream frames are pushed from the beginning of the pipeline.
|
||||
Upstream frames are pushed from the end of the pipeline.
|
||||
|
||||
Args:
|
||||
frames: An iterable or async iterable of frames to be processed.
|
||||
direction: The direction to push the frames. Defaults to downstream.
|
||||
"""
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
await self.queue_frame(frame, direction)
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def _cancel(self, *, reason: Optional[str] = None):
|
||||
"""Internal cancellation logic for the pipeline task.
|
||||
|
||||
@@ -965,7 +965,7 @@ class FrameProcessor(BaseObject):
|
||||
try:
|
||||
timestamp = self._clock.get_time() if self._clock else 0
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
logger.trace(f"Pushing {frame} downstream from {self} to {self._next}")
|
||||
|
||||
if self._observer:
|
||||
data = FramePushed(
|
||||
|
||||
@@ -292,6 +292,63 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
assert upstream_received
|
||||
assert downstream_received
|
||||
|
||||
async def test_task_queue_frame_upstream(self):
|
||||
upstream_received = False
|
||||
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
task = PipelineTask(pipeline, cancel_on_idle_timeout=False)
|
||||
task.set_reached_upstream_filter((TextFrame,))
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
nonlocal upstream_received
|
||||
if isinstance(frame, TextFrame) and frame.text == "Hello Upstream!":
|
||||
upstream_received = True
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(task, frame):
|
||||
await task.queue_frame(TextFrame(text="Hello Upstream!"), FrameDirection.UPSTREAM)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
task.run(PipelineTaskParams(loop=asyncio.get_event_loop())),
|
||||
timeout=1.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
assert upstream_received
|
||||
|
||||
async def test_task_queue_frames_upstream(self):
|
||||
upstream_texts = []
|
||||
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
task = PipelineTask(pipeline, cancel_on_idle_timeout=False)
|
||||
task.set_reached_upstream_filter((TextFrame,))
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
if isinstance(frame, TextFrame):
|
||||
upstream_texts.append(frame.text)
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(task, frame):
|
||||
await task.queue_frames(
|
||||
[TextFrame(text="First"), TextFrame(text="Second")],
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
task.run(PipelineTaskParams(loop=asyncio.get_event_loop())),
|
||||
timeout=1.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
assert "First" in upstream_texts
|
||||
assert "Second" in upstream_texts
|
||||
|
||||
async def test_task_heartbeats(self):
|
||||
heartbeats_counter = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user