diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dddd1c0de..fefdfa4b6 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,13 +4,16 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio import unittest -from pipecat.frames.frames import TextFrame +from pipecat.frames.frames import EndFrame, HeartbeatFrame, TextFrame from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.filters.identity_filter import IdentityFilter -from tests.utils import run_test +from pipecat.processors.frame_processor import FrameProcessor +from tests.utils import HeartbeatsObserver, run_test class TestPipeline(unittest.IsolatedAsyncioTestCase): @@ -48,3 +51,42 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): frames_to_send = [TextFrame(text="Hello from Pipecat!")] expected_returned_frames = [TextFrame] await run_test(pipeline, frames_to_send, expected_returned_frames) + + +class TestPipelineTask(unittest.IsolatedAsyncioTestCase): + async def test_task_single(self): + pipeline = Pipeline([IdentityFilter()]) + task = PipelineTask(pipeline) + + await task.queue_frame(TextFrame(text="Hello!")) + await task.queue_frames([TextFrame(text="Bye!"), EndFrame()]) + await task.run() + assert task.has_finished() + + async def test_task_heartbeats(self): + heartbeats_counter = 0 + + async def heartbeat_received(processor: FrameProcessor, heartbeat: HeartbeatFrame): + nonlocal heartbeats_counter + heartbeats_counter += 1 + + identity = IdentityFilter() + pipeline = Pipeline([identity]) + heartbeats_observer = HeartbeatsObserver( + target=identity, heartbeat_callback=heartbeat_received + ) + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_heartbeats=True, heartbeats_period_secs=0.2, observers=[heartbeats_observer] + ), + ) + + expected_heartbeats = 1.0 / 0.2 + + await task.queue_frame(TextFrame(text="Hello!")) + try: + await asyncio.wait_for(task.run(), timeout=1.0) + except asyncio.TimeoutError: + pass + assert heartbeats_counter == expected_heartbeats diff --git a/tests/utils.py b/tests/utils.py index 95da9f4aa..3b6a0d009 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,14 +6,16 @@ import asyncio from dataclasses import dataclass -from typing import Sequence, Tuple +from typing import Awaitable, Callable, Sequence, Tuple from pipecat.clocks.system_clock import SystemClock from pipecat.frames.frames import ( ControlFrame, Frame, + HeartbeatFrame, StartFrame, ) +from pipecat.observers.base_observer import BaseObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @@ -22,6 +24,28 @@ class EndTestFrame(ControlFrame): pass +class HeartbeatsObserver(BaseObserver): + def __init__( + self, + *, + target: FrameProcessor, + heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]], + ): + self._target = target + self._callback = heartbeat_callback + + async def on_push_frame( + self, + src: FrameProcessor, + dst: FrameProcessor, + frame: Frame, + direction: FrameDirection, + timestamp: int, + ): + if src == self._target and isinstance(frame, HeartbeatFrame): + await self._callback(self._target, frame) + + class QueuedFrameProcessor(FrameProcessor): def __init__(self, queue: asyncio.Queue, ignore_start: bool = True): super().__init__()