diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index eeb39c9b6..deae6290c 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -389,12 +389,12 @@ class PipelineTask(BasePipelineTask): # source allows us to receive and react to upstream frames, and the sink # allows us to receive and react to downstream frames. source = PipelineSource(self._source_push_frame, name=f"{self}::Source") - sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink") + self._sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink") # Only prepend the RTVIProcessor if we created it ourselves. When the # user already placed it inside their pipeline we must not insert it # again or it will appear twice in the frame chain. processors = [self._rtvi, pipeline] if prepend_rtvi else [pipeline] - self._pipeline = Pipeline(processors, source=source, sink=sink) + self._pipeline = Pipeline(processors, source=source, sink=self._sink) # The task observer acts as a proxy to the provided observers. This way, # we only need to pass a single observer (using the StartFrame) which @@ -625,26 +625,43 @@ class PipelineTask(BasePipelineTask): self._finished = True logger.debug(f"Pipeline task {self} has finished") - async def queue_frame(self, frame: Frame): - """Queue a single frame to be pushed down the pipeline. + async def queue_frame( + self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM + ): + """Queue a single frame to be pushed through the pipeline. + + Downstream frames are pushed from the beginning of the pipeline. + Upstream frames are pushed from the end of the pipeline. Args: frame: The frame to be processed. + direction: The direction to push the frame. Defaults to downstream. """ - await self._push_queue.put(frame) + if direction == FrameDirection.DOWNSTREAM: + await self._push_queue.put(frame) + else: + await self._sink.queue_frame(frame, direction) - async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): - """Queues multiple frames to be pushed down the pipeline. + async def queue_frames( + self, + frames: Iterable[Frame] | AsyncIterable[Frame], + direction: FrameDirection = FrameDirection.DOWNSTREAM, + ): + """Queue multiple frames to be pushed through the pipeline. + + Downstream frames are pushed from the beginning of the pipeline. + Upstream frames are pushed from the end of the pipeline. Args: frames: An iterable or async iterable of frames to be processed. + direction: The direction to push the frames. Defaults to downstream. """ if isinstance(frames, AsyncIterable): async for frame in frames: - await self.queue_frame(frame) + await self.queue_frame(frame, direction) elif isinstance(frames, Iterable): for frame in frames: - await self.queue_frame(frame) + await self.queue_frame(frame, direction) async def _cancel(self, *, reason: Optional[str] = None): """Internal cancellation logic for the pipeline task. diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 3e90968fe..3e7b48442 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -965,7 +965,7 @@ class FrameProcessor(BaseObject): try: timestamp = self._clock.get_time() if self._clock else 0 if direction == FrameDirection.DOWNSTREAM and self._next: - logger.trace(f"Pushing {frame} from {self} to {self._next}") + logger.trace(f"Pushing {frame} downstream from {self} to {self._next}") if self._observer: data = FramePushed( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 71121a3fc..04601bf14 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -292,6 +292,63 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase): assert upstream_received assert downstream_received + async def test_task_queue_frame_upstream(self): + upstream_received = False + + pipeline = Pipeline([IdentityFilter()]) + task = PipelineTask(pipeline, cancel_on_idle_timeout=False) + task.set_reached_upstream_filter((TextFrame,)) + + @task.event_handler("on_frame_reached_upstream") + async def on_frame_reached_upstream(task, frame): + nonlocal upstream_received + if isinstance(frame, TextFrame) and frame.text == "Hello Upstream!": + upstream_received = True + + @task.event_handler("on_pipeline_started") + async def on_pipeline_started(task, frame): + await task.queue_frame(TextFrame(text="Hello Upstream!"), FrameDirection.UPSTREAM) + + try: + await asyncio.wait_for( + task.run(PipelineTaskParams(loop=asyncio.get_event_loop())), + timeout=1.0, + ) + except asyncio.TimeoutError: + pass + + assert upstream_received + + async def test_task_queue_frames_upstream(self): + upstream_texts = [] + + pipeline = Pipeline([IdentityFilter()]) + task = PipelineTask(pipeline, cancel_on_idle_timeout=False) + task.set_reached_upstream_filter((TextFrame,)) + + @task.event_handler("on_frame_reached_upstream") + async def on_frame_reached_upstream(task, frame): + if isinstance(frame, TextFrame): + upstream_texts.append(frame.text) + + @task.event_handler("on_pipeline_started") + async def on_pipeline_started(task, frame): + await task.queue_frames( + [TextFrame(text="First"), TextFrame(text="Second")], + FrameDirection.UPSTREAM, + ) + + try: + await asyncio.wait_for( + task.run(PipelineTaskParams(loop=asyncio.get_event_loop())), + timeout=1.0, + ) + except asyncio.TimeoutError: + pass + + assert "First" in upstream_texts + assert "Second" in upstream_texts + async def test_task_heartbeats(self): heartbeats_counter = 0