tests: added PipelineTask tests

This commit is contained in:
Aleix Conchillo Flaqué
2025-01-21 11:44:13 -08:00
parent ab4221a4db
commit 401d3ff267
2 changed files with 69 additions and 3 deletions

View File

@@ -4,13 +4,16 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from pipecat.frames.frames import TextFrame
from pipecat.frames.frames import EndFrame, HeartbeatFrame, TextFrame
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.filters.identity_filter import IdentityFilter
from tests.utils import run_test
from pipecat.processors.frame_processor import FrameProcessor
from tests.utils import HeartbeatsObserver, run_test
class TestPipeline(unittest.IsolatedAsyncioTestCase):
@@ -48,3 +51,42 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
expected_returned_frames = [TextFrame]
await run_test(pipeline, frames_to_send, expected_returned_frames)
class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
async def test_task_single(self):
pipeline = Pipeline([IdentityFilter()])
task = PipelineTask(pipeline)
await task.queue_frame(TextFrame(text="Hello!"))
await task.queue_frames([TextFrame(text="Bye!"), EndFrame()])
await task.run()
assert task.has_finished()
async def test_task_heartbeats(self):
heartbeats_counter = 0
async def heartbeat_received(processor: FrameProcessor, heartbeat: HeartbeatFrame):
nonlocal heartbeats_counter
heartbeats_counter += 1
identity = IdentityFilter()
pipeline = Pipeline([identity])
heartbeats_observer = HeartbeatsObserver(
target=identity, heartbeat_callback=heartbeat_received
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_heartbeats=True, heartbeats_period_secs=0.2, observers=[heartbeats_observer]
),
)
expected_heartbeats = 1.0 / 0.2
await task.queue_frame(TextFrame(text="Hello!"))
try:
await asyncio.wait_for(task.run(), timeout=1.0)
except asyncio.TimeoutError:
pass
assert heartbeats_counter == expected_heartbeats

View File

@@ -6,14 +6,16 @@
import asyncio
from dataclasses import dataclass
from typing import Sequence, Tuple
from typing import Awaitable, Callable, Sequence, Tuple
from pipecat.clocks.system_clock import SystemClock
from pipecat.frames.frames import (
ControlFrame,
Frame,
HeartbeatFrame,
StartFrame,
)
from pipecat.observers.base_observer import BaseObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -22,6 +24,28 @@ class EndTestFrame(ControlFrame):
pass
class HeartbeatsObserver(BaseObserver):
def __init__(
self,
*,
target: FrameProcessor,
heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]],
):
self._target = target
self._callback = heartbeat_callback
async def on_push_frame(
self,
src: FrameProcessor,
dst: FrameProcessor,
frame: Frame,
direction: FrameDirection,
timestamp: int,
):
if src == self._target and isinstance(frame, HeartbeatFrame):
await self._callback(self._target, frame)
class QueuedFrameProcessor(FrameProcessor):
def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
super().__init__()