From 3c970a3ceecd8cc2287441c41bd5e6a35525ccef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 21 Jan 2025 09:43:57 -0800 Subject: [PATCH] tests: add more filter tests --- tests/test_filters.py | 54 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/test_filters.py b/tests/test_filters.py index 706ca1cca..e831f4071 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -4,16 +4,22 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio import unittest from pipecat.frames.frames import ( + EndFrame, + Frame, + TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) +from pipecat.processors.filters.frame_filter import FrameFilter +from pipecat.processors.filters.function_filter import FunctionFilter from pipecat.processors.filters.identity_filter import IdentityFilter from pipecat.processors.filters.wake_check_filter import WakeCheckFilter -from tests.utils import run_test +from tests.utils import EndTestFrame, run_test class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase): @@ -24,6 +30,52 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase): await run_test(filter, frames_to_send, expected_returned_frames) +class TestFrameFilter(unittest.IsolatedAsyncioTestCase): + async def test_text_frame(self): + filter = FrameFilter(types=(TextFrame, EndTestFrame)) + frames_to_send = [TextFrame(text="Hello Pipecat!")] + expected_returned_frames = [TextFrame] + await run_test(filter, frames_to_send, expected_returned_frames) + + async def test_end_frame(self): + filter = FrameFilter(types=(EndFrame, EndTestFrame)) + frames_to_send = [EndFrame()] + expected_returned_frames = [EndFrame] + await run_test(filter, frames_to_send, expected_returned_frames) + + async def test_system_frame(self): + filter = FrameFilter(types=(EndTestFrame,)) + frames_to_send = [UserStartedSpeakingFrame()] + expected_returned_frames = [UserStartedSpeakingFrame] + await run_test(filter, frames_to_send, expected_returned_frames) + + +class TestFunctionFilter(unittest.IsolatedAsyncioTestCase): + async def test_passthrough(self): + async def passthrough(frame: Frame): + return True + + filter = FunctionFilter(filter=passthrough) + frames_to_send = [TextFrame(text="Hello Pipecat!")] + expected_returned_frames = [TextFrame] + await run_test(filter, frames_to_send, expected_returned_frames) + + async def test_no_passthrough(self): + async def no_passthrough(frame: Frame): + return False + + filter = FunctionFilter(filter=no_passthrough) + frames_to_send = [TextFrame(text="Hello Pipecat!")] + expected_returned_frames = [TextFrame] + try: + await asyncio.wait_for( + run_test(filter, frames_to_send, expected_returned_frames), timeout=0.5 + ) + assert False + except asyncio.TimeoutError: + pass + + class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase): async def test_no_wake_word(self): filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])