tests: added PipelineTask tests
This commit is contained in:
@@ -4,13 +4,16 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.frames.frames import EndFrame, HeartbeatFrame, TextFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from tests.utils import run_test
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from tests.utils import HeartbeatsObserver, run_test
|
||||
|
||||
|
||||
class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -48,3 +51,42 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(pipeline, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_task_single(self):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
await task.queue_frames([TextFrame(text="Bye!"), EndFrame()])
|
||||
await task.run()
|
||||
assert task.has_finished()
|
||||
|
||||
async def test_task_heartbeats(self):
|
||||
heartbeats_counter = 0
|
||||
|
||||
async def heartbeat_received(processor: FrameProcessor, heartbeat: HeartbeatFrame):
|
||||
nonlocal heartbeats_counter
|
||||
heartbeats_counter += 1
|
||||
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
heartbeats_observer = HeartbeatsObserver(
|
||||
target=identity, heartbeat_callback=heartbeat_received
|
||||
)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_heartbeats=True, heartbeats_period_secs=0.2, observers=[heartbeats_observer]
|
||||
),
|
||||
)
|
||||
|
||||
expected_heartbeats = 1.0 / 0.2
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
try:
|
||||
await asyncio.wait_for(task.run(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
assert heartbeats_counter == expected_heartbeats
|
||||
|
||||
@@ -6,14 +6,16 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Awaitable, Callable, Sequence, Tuple
|
||||
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@@ -22,6 +24,28 @@ class EndTestFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
class HeartbeatsObserver(BaseObserver):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
target: FrameProcessor,
|
||||
heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]],
|
||||
):
|
||||
self._target = target
|
||||
self._callback = heartbeat_callback
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
if src == self._target and isinstance(frame, HeartbeatFrame):
|
||||
await self._callback(self._target, frame)
|
||||
|
||||
|
||||
class QueuedFrameProcessor(FrameProcessor):
|
||||
def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user