# # Copyright (c) 2024-2026, Daily # # SPDX-License-Identifier: BSD 2-Clause License # """Unit tests for ServiceSwitcher and related components.""" import asyncio import unittest from dataclasses import dataclass from pipecat.frames.frames import ( Frame, ManuallySwitchServiceFrame, ServiceMetadataFrame, ServiceSwitcherRequestMetadataFrame, StartFrame, SystemFrame, TextFrame, ) from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.tests.utils import run_test class MockFrameProcessor(FrameProcessor): """A test frame processor that tracks which frames it has processed.""" def __init__(self, test_name: str, **kwargs): """Initialize the test processor with a name. Args: test_name: A unique name for this processor instance. **kwargs: Additional arguments passed to the parent FrameProcessor. """ super().__init__(name=test_name, **kwargs) self.test_name = test_name self.processed_frames = [] self.frame_count = 0 async def process_frame(self, frame: Frame, direction: FrameDirection): """Process an incoming frame and track it. Args: frame: The frame to process. direction: The direction of frame flow in the pipeline. """ await super().process_frame(frame, direction) self.processed_frames.append(frame) self.frame_count += 1 await self.push_frame(frame, direction) def reset_counters(self): """Reset the frame tracking counters.""" self.processed_frames = [] self.frame_count = 0 @dataclass class MockMetadataFrame(ServiceMetadataFrame): """A mock metadata frame for testing ServiceMetadataFrame handling.""" pass class MockMetadataService(FrameProcessor): """A mock service that emits ServiceMetadataFrame like STT services. Pushes MockMetadataFrame on StartFrame and ServiceSwitcherRequestMetadataFrame. """ def __init__(self, test_name: str, **kwargs): super().__init__(name=test_name, **kwargs) self.test_name = test_name self.processed_frames = [] self.metadata_push_count = 0 async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) self.processed_frames.append(frame) if isinstance(frame, StartFrame): await self.push_frame(frame, direction) await self._push_metadata() elif isinstance(frame, ServiceSwitcherRequestMetadataFrame): await self._push_metadata() await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) async def _push_metadata(self): self.metadata_push_count += 1 await self.push_frame(MockMetadataFrame(service_name=self.test_name)) def reset_counters(self): self.processed_frames = [] self.metadata_push_count = 0 @dataclass class DummySystemFrame(SystemFrame): """A dummy system frame for testing purposes.""" text: str = "" class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase): """Test cases for ServiceSwitcherStrategyManual.""" def setUp(self): """Set up test fixtures.""" self.service1 = MockFrameProcessor("service1") self.service2 = MockFrameProcessor("service2") self.service3 = MockFrameProcessor("service3") self.services = [self.service1, self.service2, self.service3] def test_init_with_services(self): """Test initialization with a list of services.""" strategy = ServiceSwitcherStrategyManual(self.services) self.assertEqual(strategy.services, self.services) self.assertEqual(strategy.active_service, self.service1) # First service should be active async def test_handle_manually_switch_service_frame(self): """Test manual service switching with ManuallySwitchServiceFrame.""" strategy = ServiceSwitcherStrategyManual(self.services) # Initially service1 should be active self.assertEqual(strategy.active_service, self.service1) self.assertNotEqual(strategy.active_service, self.service2) # Switch to service2 switch_frame = ManuallySwitchServiceFrame(service=self.service2) await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM) self.assertNotEqual(strategy.active_service, self.service1) self.assertEqual(strategy.active_service, self.service2) self.assertNotEqual(strategy.active_service, self.service3) # Switch to service3 switch_frame = ManuallySwitchServiceFrame(service=self.service3) await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM) self.assertNotEqual(strategy.active_service, self.service1) self.assertNotEqual(strategy.active_service, self.service2) self.assertEqual(strategy.active_service, self.service3) async def test_on_service_switched_event(self): """Test that on_service_switched event fires with correct arguments.""" strategy = ServiceSwitcherStrategyManual(self.services) switched_events = [] @strategy.event_handler("on_service_switched") async def on_service_switched(strategy, service): switched_events.append((strategy, service)) # Switch to service2 switch_frame = ManuallySwitchServiceFrame(service=self.service2) await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM) await asyncio.sleep(0) # Let async event task run self.assertEqual(len(switched_events), 1) self.assertIsInstance(switched_events[0][0], ServiceSwitcherStrategyManual) self.assertEqual(switched_events[0][1], self.service2) # Switch to service3 switch_frame = ManuallySwitchServiceFrame(service=self.service3) await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM) await asyncio.sleep(0) self.assertEqual(len(switched_events), 2) self.assertEqual(switched_events[1][1], self.service3) async def test_on_service_switched_event_not_fired_for_unknown_service(self): """Test that on_service_switched event does not fire for services not in the list.""" strategy = ServiceSwitcherStrategyManual(self.services) switched_events = [] @strategy.event_handler("on_service_switched") async def on_service_switched(strategy, service): switched_events.append(service) # Try switching to a service not in the list unknown_service = MockFrameProcessor("unknown") switch_frame = ManuallySwitchServiceFrame(service=unknown_service) await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM) await asyncio.sleep(0) self.assertEqual(len(switched_events), 0) self.assertEqual(strategy.active_service, self.service1) # Unchanged async def test_handle_frame_unsupported_frame_type(self): """Test that unsupported frame types raise an error.""" strategy = ServiceSwitcherStrategyManual(self.services) unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame result = await strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM) self.assertIsNone(result) class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase): """Test cases for ServiceSwitcher.""" def setUp(self): """Set up test fixtures.""" self.service1 = MockFrameProcessor("service1") self.service2 = MockFrameProcessor("service2") self.service3 = MockFrameProcessor("service3") self.services = [self.service1, self.service2, self.service3] def test_init_with_manual_strategy(self): """Test initialization with manual strategy.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) self.assertEqual(switcher.services, self.services) self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual) self.assertEqual(switcher.strategy.services, self.services) async def test_default_active_service(self): """Test that the initially-active service receives frames while others don't.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) # Reset counters for service in self.services: service.reset_counters() # Send some test frames frames_to_send = [ TextFrame(text="Hello 1"), DummySystemFrame(text="System Message 1"), TextFrame(text="Hello 2"), DummySystemFrame(text="System Message 2"), TextFrame(text="Hello 3"), ] await run_test( switcher, frames_to_send=frames_to_send, expected_down_frames=[ DummySystemFrame, DummySystemFrame, TextFrame, TextFrame, TextFrame, ], expected_up_frames=[], # Expect no error frames ) # Only service1 should have processed the text frames # Note: The service also receives StartFrame and EndFrame, so count those too text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)] self.assertEqual(len(text_frames), 3) # Only service1 should have processed the system frames system_frames = [ f for f in self.service1.processed_frames if isinstance(f, DummySystemFrame) ] self.assertEqual(len(system_frames), 2) # Check that other services don't receive text frames (they still get StartFrame/EndFrame) service2_text_frames = [ f for f in self.service2.processed_frames if isinstance(f, TextFrame) ] service3_text_frames = [ f for f in self.service3.processed_frames if isinstance(f, TextFrame) ] self.assertEqual(len(service2_text_frames), 0) self.assertEqual(len(service3_text_frames), 0) # Check that other services don't receive dummy system frames (they still get StartFrame/EndFrame) service2_system_frames = [ f for f in self.service2.processed_frames if isinstance(f, DummySystemFrame) ] service3_system_frames = [ f for f in self.service3.processed_frames if isinstance(f, DummySystemFrame) ] self.assertEqual(len(service2_system_frames), 0) self.assertEqual(len(service3_system_frames), 0) # Verify the actual text frames processed for i, frame in enumerate(text_frames): self.assertEqual(frame.text, f"Hello {i + 1}") # Verify the actual system frames processed for i, frame in enumerate(system_frames): self.assertEqual(frame.text, f"System Message {i + 1}") async def test_service_switching(self): """Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) # Reset counters for service in self.services: service.reset_counters() # Send a test frame, a switch frame, and another test frame await run_test( switcher, frames_to_send=[ TextFrame("Hello 1"), ManuallySwitchServiceFrame(service=self.service2), TextFrame("Hello 2"), ], expected_down_frames=[TextFrame, TextFrame], expected_up_frames=[], # Expect no error frames ) # Verify service2 received the frame service1_text_frames = [ f for f in self.service1.processed_frames if isinstance(f, TextFrame) ] service2_text_frames = [ f for f in self.service2.processed_frames if isinstance(f, TextFrame) ] service3_text_frames = [ f for f in self.service3.processed_frames if isinstance(f, TextFrame) ] self.assertEqual(len(service1_text_frames), 1) self.assertEqual(len(service2_text_frames), 1) self.assertEqual(len(service3_text_frames), 0) self.assertEqual(service1_text_frames[0].text, "Hello 1") self.assertEqual(service2_text_frames[0].text, "Hello 2") async def test_multi_service_switcher_targeting(self): """Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline.""" # Create services for first switcher switcher1_service1 = MockFrameProcessor("switcher1_service1") switcher1_service2 = MockFrameProcessor("switcher1_service2") switcher1_services = [switcher1_service1, switcher1_service2] # Create services for second switcher switcher2_service1 = MockFrameProcessor("switcher2_service1") switcher2_service2 = MockFrameProcessor("switcher2_service2") switcher2_services = [switcher2_service1, switcher2_service2] # Create two service switchers switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual) switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual) # Create a pipeline with both switchers: switcher1 -> switcher2 pipeline = Pipeline([switcher1, switcher2]) # Reset counters for service in switcher1_services + switcher2_services: service.reset_counters() # Initially, both switchers should use their first services self.assertEqual(switcher1.strategy.active_service, switcher1_service1) self.assertEqual(switcher2.strategy.active_service, switcher2_service1) # Send frames to test the pipeline: # 1. Text frame (should go through both switchers' active services) # 2. Switch frame targeting switcher1's second service # 3. Text frame (should go through switcher1's new service and switcher2's original service) # 4. Switch frame targeting switcher2's second service # 5. Text frame (should go through switcher1's current service and switcher2's new service) await run_test( pipeline, frames_to_send=[ TextFrame("Before any switches"), ManuallySwitchServiceFrame(service=switcher1_service2), # Switch first switcher TextFrame("After switching first switcher"), ManuallySwitchServiceFrame(service=switcher2_service2), # Switch second switcher TextFrame("After switching second switcher"), ], expected_down_frames=[ TextFrame, TextFrame, TextFrame, ], expected_up_frames=[], # Expect no error frames ) # Verify the active services changed correctly self.assertEqual(switcher1.strategy.active_service, switcher1_service2) self.assertEqual(switcher2.strategy.active_service, switcher2_service2) # Verify frame distribution: # First text frame should go through switcher1_service1 and switcher2_service1 switcher1_service1_texts = [ f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame) ] switcher2_service1_texts = [ f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame) ] # Second text frame should go through switcher1_service2 and switcher2_service1 switcher1_service2_texts = [ f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame) ] # Third text frame should go through switcher1_service2 and switcher2_service2 switcher2_service2_texts = [ f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame) ] # Verify frame counts and content self.assertEqual(len(switcher1_service1_texts), 1) self.assertEqual(switcher1_service1_texts[0].text, "Before any switches") self.assertEqual(len(switcher1_service2_texts), 2) self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher") self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher") self.assertEqual(len(switcher2_service1_texts), 2) self.assertEqual(switcher2_service1_texts[0].text, "Before any switches") self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher") self.assertEqual(len(switcher2_service2_texts), 1) self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher") class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase): """Test cases for ServiceMetadataFrame handling in ServiceSwitcher.""" def setUp(self): """Set up test fixtures with mock metadata services.""" self.service1 = MockMetadataService("service1") self.service2 = MockMetadataService("service2") self.services = [self.service1, self.service2] async def test_only_active_service_metadata_at_startup(self): """Test that only the active service's metadata leaves the ServiceSwitcher at startup.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) # Run the pipeline (StartFrame triggers metadata emission) output_frames = [] async def capture_frame(frame: Frame): output_frames.append(frame) await run_test( switcher, frames_to_send=[TextFrame(text="test")], expected_down_frames=[MockMetadataFrame, TextFrame], expected_up_frames=[], ) # Both services push metadata internally on StartFrame, but only the # active service's metadata passes through the filter self.assertEqual(self.service1.metadata_push_count, 1) # StartFrame (passes filter) self.assertEqual(self.service2.metadata_push_count, 1) # StartFrame (blocked by filter) async def test_metadata_emitted_on_service_switch(self): """Test that switching services triggers metadata emission from the new active service.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) # Reset counters after startup self.service1.reset_counters() self.service2.reset_counters() await run_test( switcher, frames_to_send=[ TextFrame(text="before switch"), ManuallySwitchServiceFrame(service=self.service2), TextFrame(text="after switch"), ], expected_down_frames=[ MockMetadataFrame, # From startup (service1) TextFrame, MockMetadataFrame, # From service2 after switch TextFrame, ], expected_up_frames=[], ) # service2 should have received ServiceSwitcherRequestMetadataFrame after becoming active request_frames = [ f for f in self.service2.processed_frames if isinstance(f, ServiceSwitcherRequestMetadataFrame) ] self.assertEqual(len(request_frames), 1) async def test_inactive_service_metadata_blocked(self): """Test that metadata from inactive services is blocked.""" switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual) # Run and collect output frames await run_test( switcher, frames_to_send=[TextFrame(text="test")], expected_down_frames=[MockMetadataFrame, TextFrame], expected_up_frames=[], ) # service2 pushed metadata on StartFrame, but it should have been blocked self.assertGreaterEqual(self.service2.metadata_push_count, 1) # Only one MockMetadataFrame should have left (from service1) if __name__ == "__main__": unittest.main()