Add unit tests for ServiceSwitcher
This commit is contained in:
254
tests/test_service_switcher.py
Normal file
254
tests/test_service_switcher.py
Normal file
@@ -0,0 +1,254 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Unit tests for ServiceSwitcher and related components."""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, TextFrame
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
class MockFrameProcessor(FrameProcessor):
|
||||
"""A test frame processor that tracks which frames it has processed."""
|
||||
|
||||
def __init__(self, test_name: str, **kwargs):
|
||||
"""Initialize the test processor with a name.
|
||||
|
||||
Args:
|
||||
test_name: A unique name for this processor instance.
|
||||
**kwargs: Additional arguments passed to the parent FrameProcessor.
|
||||
"""
|
||||
super().__init__(name=test_name, **kwargs)
|
||||
self.test_name = test_name
|
||||
self.processed_frames = []
|
||||
self.frame_count = 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process an incoming frame and track it.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
self.processed_frames.append(frame)
|
||||
self.frame_count += 1
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def reset_counters(self):
|
||||
"""Reset the frame tracking counters."""
|
||||
self.processed_frames = []
|
||||
self.frame_count = 0
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_with_services(self):
|
||||
"""Test initialization with a list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
self.assertEqual(strategy.services, self.services)
|
||||
self.assertEqual(strategy.active_service, self.service1) # First service should be active
|
||||
|
||||
def test_init_with_empty_services(self):
|
||||
"""Test initialization with an empty list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual([])
|
||||
|
||||
self.assertEqual(strategy.services, [])
|
||||
self.assertIsNone(strategy.active_service)
|
||||
|
||||
def test_handle_frame_manually_switch_service(self):
|
||||
"""Test manual service switching with ManuallySwitchServiceFrame."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
# Initially service1 should be active
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
self.assertNotEqual(strategy.active_service, self.service3)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
def test_handle_frame_invalid_service(self):
|
||||
"""Test that switching to an invalid service raises an error."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
invalid_service = MockFrameProcessor("invalid")
|
||||
|
||||
switch_frame = ManuallySwitchServiceFrame(service=invalid_service)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIn("Service", str(context.exception))
|
||||
self.assertIn("is not in the list of available services", str(context.exception))
|
||||
|
||||
def test_handle_frame_unsupported_frame_type(self):
|
||||
"""Test that unsupported frame types raise an error."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIn("Unsupported frame type", str(context.exception))
|
||||
|
||||
|
||||
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcher."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_with_manual_strategy(self):
|
||||
"""Test initialization with manual strategy."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
self.assertEqual(switcher.services, self.services)
|
||||
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
|
||||
self.assertEqual(switcher.strategy.services, self.services)
|
||||
|
||||
async def test_default_active_service(self):
|
||||
"""Test that the initially-active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
service.reset_counters()
|
||||
|
||||
# Send some test frames
|
||||
frames_to_send = [
|
||||
TextFrame(text="Hello 1"),
|
||||
TextFrame(text="Hello 2"),
|
||||
TextFrame(text="Hello 3"),
|
||||
]
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[TextFrame, TextFrame, TextFrame],
|
||||
)
|
||||
|
||||
# Only service1 should have processed the text frames
|
||||
# Note: The service also receives StartFrame and EndFrame, so count those too
|
||||
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
|
||||
self.assertEqual(len(text_frames), 3)
|
||||
|
||||
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service3_text_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
self.assertEqual(len(service2_text_frames), 0)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
|
||||
# Verify the actual text frames processed
|
||||
for i, frame in enumerate(text_frames):
|
||||
self.assertEqual(frame.text, f"Hello {i + 1}")
|
||||
|
||||
async def test_service_switching(self):
|
||||
"""Test that after service switching the new active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
service.reset_counters()
|
||||
|
||||
# Send a test frame, a switch frame, and another test frame
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=[
|
||||
TextFrame("Frame for service1"),
|
||||
ManuallySwitchServiceFrame(service=self.service2),
|
||||
TextFrame("Frame for service2"),
|
||||
],
|
||||
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
|
||||
)
|
||||
|
||||
# Verify service2 received the frame
|
||||
service1_text_frames = [
|
||||
f for f in self.service1.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service3_text_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
|
||||
self.assertEqual(len(service1_text_frames), 1)
|
||||
self.assertEqual(len(service2_text_frames), 1)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
self.assertEqual(service1_text_frames[0].text, "Frame for service1")
|
||||
self.assertEqual(service2_text_frames[0].text, "Frame for service2")
|
||||
|
||||
async def test_service_switching_functionality(self):
|
||||
"""Test that switching between services works correctly."""
|
||||
# This test is simplified to avoid timing issues
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Test that we can programmatically switch services
|
||||
self.assertEqual(switcher.strategy.active_service, self.service1) # Initially service1
|
||||
|
||||
# Manually switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
switcher.strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify the switch worked
|
||||
self.assertNotEqual(switcher.strategy.active_service, self.service1)
|
||||
self.assertEqual(switcher.strategy.active_service, self.service2)
|
||||
self.assertNotEqual(switcher.strategy.active_service, self.service3)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
switcher.strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify the switch worked
|
||||
self.assertNotEqual(switcher.strategy.active_service, self.service1)
|
||||
self.assertNotEqual(switcher.strategy.active_service, self.service2)
|
||||
self.assertEqual(switcher.strategy.active_service, self.service3)
|
||||
|
||||
async def test_service_switching_with_empty_services_list(self):
|
||||
"""Test behavior with an empty services list."""
|
||||
# ServiceSwitcher should handle empty services gracefully, but ParallelPipeline needs at least one pipeline
|
||||
# So this test verifies that an exception is raised as expected
|
||||
with self.assertRaises(Exception) as context:
|
||||
ServiceSwitcher([], ServiceSwitcherStrategyManual)
|
||||
|
||||
self.assertIn("ParallelPipeline needs at least one argument", str(context.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user