tests: add more filter tests
This commit is contained in:
@@ -4,16 +4,22 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.frame_filter import FrameFilter
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.filters.wake_check_filter import WakeCheckFilter
|
||||
from tests.utils import run_test
|
||||
from tests.utils import EndTestFrame, run_test
|
||||
|
||||
|
||||
class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -24,6 +30,52 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase):
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_text_frame(self):
|
||||
filter = FrameFilter(types=(TextFrame, EndTestFrame))
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_end_frame(self):
|
||||
filter = FrameFilter(types=(EndFrame, EndTestFrame))
|
||||
frames_to_send = [EndFrame()]
|
||||
expected_returned_frames = [EndFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_system_frame(self):
|
||||
filter = FrameFilter(types=(EndTestFrame,))
|
||||
frames_to_send = [UserStartedSpeakingFrame()]
|
||||
expected_returned_frames = [UserStartedSpeakingFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_passthrough(self):
|
||||
async def passthrough(frame: Frame):
|
||||
return True
|
||||
|
||||
filter = FunctionFilter(filter=passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_no_passthrough(self):
|
||||
async def no_passthrough(frame: Frame):
|
||||
return False
|
||||
|
||||
filter = FunctionFilter(filter=no_passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
run_test(filter, frames_to_send, expected_returned_frames), timeout=0.5
|
||||
)
|
||||
assert False
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_no_wake_word(self):
|
||||
filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])
|
||||
|
||||
Reference in New Issue
Block a user