Update TestServiceSwitcher to exercise targeting system frames only to the active service
This commit is contained in:
@@ -7,10 +7,12 @@
|
||||
"""Unit tests for ServiceSwitcher and related components."""
|
||||
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -52,6 +54,13 @@ class MockFrameProcessor(FrameProcessor):
|
||||
self.frame_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySystemFrame(SystemFrame):
|
||||
"""A dummy system frame for testing purposes."""
|
||||
|
||||
text: str = ""
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
|
||||
@@ -140,14 +149,22 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
# Send some test frames
|
||||
frames_to_send = [
|
||||
TextFrame(text="Hello 1"),
|
||||
DummySystemFrame(text="System Message 1"),
|
||||
TextFrame(text="Hello 2"),
|
||||
DummySystemFrame(text="System Message 2"),
|
||||
TextFrame(text="Hello 3"),
|
||||
]
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[TextFrame, TextFrame, TextFrame],
|
||||
expected_down_frames=[
|
||||
DummySystemFrame,
|
||||
DummySystemFrame,
|
||||
TextFrame,
|
||||
TextFrame,
|
||||
TextFrame,
|
||||
],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
@@ -156,7 +173,13 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
|
||||
self.assertEqual(len(text_frames), 3)
|
||||
|
||||
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
|
||||
# Only service1 should have processed the system frames
|
||||
system_frames = [
|
||||
f for f in self.service1.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
self.assertEqual(len(system_frames), 2)
|
||||
|
||||
# Check that other services don't receive text frames (they still get StartFrame/EndFrame)
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
@@ -166,10 +189,24 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(service2_text_frames), 0)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
|
||||
# Check that other services don't receive dummy system frames (they still get StartFrame/EndFrame)
|
||||
service2_system_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
service3_system_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
self.assertEqual(len(service2_system_frames), 0)
|
||||
self.assertEqual(len(service3_system_frames), 0)
|
||||
|
||||
# Verify the actual text frames processed
|
||||
for i, frame in enumerate(text_frames):
|
||||
self.assertEqual(frame.text, f"Hello {i + 1}")
|
||||
|
||||
# Verify the actual system frames processed
|
||||
for i, frame in enumerate(system_frames):
|
||||
self.assertEqual(frame.text, f"System Message {i + 1}")
|
||||
|
||||
async def test_service_switching(self):
|
||||
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
Reference in New Issue
Block a user