Simplify ServiceSwitcher with closure-based filters
- Make ServiceSwitcherStrategy inherit from BaseObject with properties for services and active_service, and move initial service selection into the base class - Add on_service_switched event to ServiceSwitcherStrategy - handle_frame now returns the switched-to service (or None), allowing ServiceSwitcher to swallow ManuallySwitchServiceFrame on switch and request metadata from the new active service - Override push_frame to suppress RequestMetadataFrame and ServiceMetadataFrame from inactive services - Remove ServiceSwitcherFilter and ServiceSwitcherFilterFrame in favor of plain FunctionFilter instances with closures that check the strategy's active service directly - FunctionFilter: add FilterType alias - FunctionFilter: when direction is None, frames in both directions are filtered instead of just one - Add docstrings to ServiceSwitcher and its components
This commit is contained in:
committed by
Mark Backman
parent
5e66702cf5
commit
2a572aedba
@@ -44,7 +44,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
return self.services
|
||||
|
||||
@property
|
||||
def active_llm(self) -> Optional[LLMService]:
|
||||
def active_llm(self) -> LLMService:
|
||||
"""Get the currently active LLM.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -6,11 +6,10 @@
|
||||
|
||||
"""Service switcher for switching between different services at runtime, with different switching strategies."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
RequestMetadataFrame,
|
||||
@@ -20,14 +19,25 @@ from pipecat.frames.frames import (
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
class ServiceSwitcherStrategy:
|
||||
class ServiceSwitcherStrategy(BaseObject):
|
||||
"""Base class for service switching strategies.
|
||||
|
||||
Note:
|
||||
Strategy classes are instantiated internally by ServiceSwitcher.
|
||||
Developers should pass the strategy class (not an instance) to ServiceSwitcher.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_service_switched: Called when the active service changes.
|
||||
|
||||
Example::
|
||||
|
||||
@strategy.event_handler("on_service_switched")
|
||||
async def on_service_switched(strategy, service):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor]):
|
||||
@@ -39,20 +49,42 @@ class ServiceSwitcherStrategy:
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
"""
|
||||
self.services = services
|
||||
self.active_service: Optional[FrameProcessor] = None
|
||||
super().__init__()
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
if len(services) == 0:
|
||||
raise Exception(f"ServiceSwitcherStrategy needs at least one service")
|
||||
|
||||
self._services = services
|
||||
self._active_service = services[0]
|
||||
|
||||
self._register_event_handler("on_service_switched")
|
||||
|
||||
@property
|
||||
def services(self) -> List[FrameProcessor]:
|
||||
"""Return the list of available services."""
|
||||
return self._services
|
||||
|
||||
@property
|
||||
def active_service(self) -> FrameProcessor:
|
||||
"""Return the currently active service."""
|
||||
return self._active_service
|
||||
|
||||
@abstractmethod
|
||||
async def handle_frame(
|
||||
self, frame: ServiceSwitcherFrame, direction: FrameDirection
|
||||
) -> Optional[FrameProcessor]:
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
This method can be overridden by subclasses to implement specific logic
|
||||
for handling frames that control service switching.
|
||||
Subclasses implement this to decide whether a switch should occur.
|
||||
|
||||
Args:
|
||||
frame: The frame to handle.
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
|
||||
Returns:
|
||||
The newly active service if a switch occurred, or None otherwise.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method.")
|
||||
pass
|
||||
|
||||
|
||||
class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
@@ -69,31 +101,24 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor]):
|
||||
"""Initialize the manual service switcher strategy with a list of services.
|
||||
|
||||
Note:
|
||||
This is called internally by ServiceSwitcher. Do not instantiate directly.
|
||||
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
"""
|
||||
super().__init__(services)
|
||||
self.active_service = services[0] if services else None
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
async def handle_frame(
|
||||
self, frame: ServiceSwitcherFrame, direction: FrameDirection
|
||||
) -> Optional[FrameProcessor]:
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
Args:
|
||||
frame: The frame to handle.
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
|
||||
Returns:
|
||||
The newly active service if a switch occurred, or None otherwise.
|
||||
"""
|
||||
if isinstance(frame, ManuallySwitchServiceFrame):
|
||||
self._set_active_if_available(frame.service)
|
||||
else:
|
||||
raise ValueError(f"Unsupported frame type: {type(frame)}")
|
||||
return await self._set_active_if_available(frame.service)
|
||||
|
||||
def _set_active_if_available(self, service: FrameProcessor):
|
||||
return None
|
||||
|
||||
async def _set_active_if_available(self, service: FrameProcessor) -> Optional[FrameProcessor]:
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
@@ -101,16 +126,35 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
|
||||
Returns:
|
||||
The newly active service, or None if the service was not found.
|
||||
"""
|
||||
if service in self.services:
|
||||
self.active_service = service
|
||||
self._active_service = service
|
||||
await self._call_event_handler("on_service_switched", service)
|
||||
return service
|
||||
return None
|
||||
|
||||
|
||||
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
|
||||
|
||||
|
||||
class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
"""A pipeline that switches between different services at runtime."""
|
||||
"""Parallel pipeline that routes frames to one active service at a time.
|
||||
|
||||
Wraps each service in a pair of filters that gate frame flow based on
|
||||
which service is currently active. Switching is controlled by
|
||||
`ServiceSwitcherFrame` frames and delegated to a pluggable
|
||||
`ServiceSwitcherStrategy`.
|
||||
|
||||
Example::
|
||||
|
||||
switcher = ServiceSwitcher(
|
||||
services=[stt_1, stt_2],
|
||||
strategy_type=ServiceSwitcherStrategyManual,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor], strategy_type: Type[StrategyType]):
|
||||
"""Initialize the service switcher with a list of services and a switching strategy.
|
||||
@@ -119,101 +163,20 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
services: List of frame processors to switch between.
|
||||
strategy_type: The strategy class to use for switching between services.
|
||||
"""
|
||||
strategy = strategy_type(services)
|
||||
super().__init__(*self._make_pipeline_definitions(services, strategy))
|
||||
self.services = services
|
||||
self.strategy = strategy
|
||||
_strategy = strategy_type(services)
|
||||
super().__init__(*self._make_pipeline_definitions(services, _strategy))
|
||||
self._services = services
|
||||
self._strategy = _strategy
|
||||
|
||||
class ServiceSwitcherFilter(FunctionFilter):
|
||||
"""An internal filter that gates frame flow based on active service.
|
||||
@property
|
||||
def strategy(self) -> StrategyType:
|
||||
"""Return the active switching strategy."""
|
||||
return self._strategy
|
||||
|
||||
Two filters "sandwich" each service, allowing frames through only
|
||||
when the wrapped service is active. The pipeline layout is::
|
||||
|
||||
DownstreamFilter → Service → UpstreamFilter
|
||||
|
||||
The filter names refer to which *direction* of frame flow they
|
||||
filter, not their physical position: the downstream filter sits
|
||||
*before* the service (filtering frames flowing into it) and the
|
||||
upstream filter sits *after* it (filtering frames flowing back out).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapped_service: FrameProcessor,
|
||||
active_service: FrameProcessor,
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction.
|
||||
|
||||
Args:
|
||||
wrapped_service: The service that this filter wraps.
|
||||
active_service: The currently active service.
|
||||
direction: The direction of frame flow this filter gates
|
||||
(DOWNSTREAM for the filter before the service,
|
||||
UPSTREAM for the filter after it).
|
||||
"""
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
|
||||
async def filter(_: Frame) -> bool:
|
||||
return self._wrapped_service == self._active_service
|
||||
|
||||
super().__init__(filter, direction, filter_system_frames=True)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
|
||||
old_active = self._active_service
|
||||
self._active_service = frame.active_service
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. The
|
||||
# frame enters via the downstream filter first. Push it
|
||||
# through so the upstream filter also updates its state.
|
||||
if direction == self._direction:
|
||||
await self.push_frame(frame, direction)
|
||||
# This is the upstream filter (the second to update). At
|
||||
# this point both filters know the new active service, so
|
||||
# it's safe to request metadata — the resulting
|
||||
# ServiceMetadataFrame will pass both filters on its way
|
||||
# out. Only do this for the newly active service's sandwich.
|
||||
elif (
|
||||
self._direction == FrameDirection.UPSTREAM
|
||||
and self._wrapped_service == frame.active_service
|
||||
and old_active != self._wrapped_service
|
||||
):
|
||||
await self.push_frame(RequestMetadataFrame(), FrameDirection.UPSTREAM)
|
||||
return
|
||||
|
||||
# RequestMetadataFrame is pushed upstream by the upstream filter
|
||||
# (above) and consumed by the service. Guard against services
|
||||
# that don't consume it: only forward in the filter's own
|
||||
# direction (so it can reach the service) and only for the
|
||||
# active service. Block in all other cases to prevent it from
|
||||
# escaping the sandwich.
|
||||
if isinstance(frame, RequestMetadataFrame):
|
||||
if direction == self._direction and self._wrapped_service == self._active_service:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# Block ServiceMetadataFrame from inactive services.
|
||||
if isinstance(frame, ServiceMetadataFrame):
|
||||
if self._wrapped_service != self._active_service:
|
||||
return
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used to update filter state on service switch.
|
||||
|
||||
Sent when a service switch occurs to update the active service in
|
||||
the sandwich filters and trigger metadata emission from the newly
|
||||
active service.
|
||||
"""
|
||||
|
||||
active_service: FrameProcessor
|
||||
@property
|
||||
def services(self) -> List[FrameProcessor]:
|
||||
"""Return the list of available services."""
|
||||
return self._services
|
||||
|
||||
@staticmethod
|
||||
def _make_pipeline_definitions(
|
||||
@@ -228,21 +191,53 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def _make_pipeline_definition(
|
||||
service: FrameProcessor, strategy: ServiceSwitcherStrategy
|
||||
) -> Any:
|
||||
# Layout: DownstreamFilter → Service → UpstreamFilter
|
||||
async def filter(_: Frame) -> bool:
|
||||
return service == strategy.active_service
|
||||
|
||||
# Layout: Filter → Service → Filter
|
||||
#
|
||||
# filter_system_frames: we want to run filter functions also on system
|
||||
# frames.
|
||||
#
|
||||
# enable_direct_mode: filter functions are quick so we don't need
|
||||
# additional tasks.
|
||||
return [
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
FunctionFilter(
|
||||
filter=filter,
|
||||
direction=FrameDirection.DOWNSTREAM,
|
||||
filter_system_frames=True,
|
||||
enable_direct_mode=True,
|
||||
),
|
||||
service,
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
FunctionFilter(
|
||||
filter=filter,
|
||||
direction=FrameDirection.UPSTREAM,
|
||||
filter_system_frames=True,
|
||||
enable_direct_mode=True,
|
||||
),
|
||||
]
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame out of the service switcher.
|
||||
|
||||
Suppresses `RequestMetadataFrame` (internal to the switcher) and
|
||||
`ServiceMetadataFrame` from inactive services so only the active
|
||||
service's metadata reaches downstream processors. One case this happens
|
||||
is with `StartFrame` since all the filters let it pass, and `StartFrame`
|
||||
causes the service to generate `ServiceMetadataFrame`.
|
||||
|
||||
"""
|
||||
# Don't let RequestMetadataFrame out.
|
||||
if isinstance(frame, RequestMetadataFrame):
|
||||
return
|
||||
|
||||
# Only let metadata from the active service escape.
|
||||
if isinstance(frame, ServiceMetadataFrame):
|
||||
if frame.service_name != self.strategy.active_service.name:
|
||||
return
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame, handling frames which affect service switching.
|
||||
|
||||
@@ -250,11 +245,16 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
frame: The frame to process.
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, ServiceSwitcherFrame):
|
||||
self.strategy.handle_frame(frame, direction)
|
||||
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
|
||||
active_service=self.strategy.active_service
|
||||
)
|
||||
await super().process_frame(service_switcher_filter_frame, direction)
|
||||
service = await self.strategy.handle_frame(frame, direction)
|
||||
|
||||
# If we don't switch to a new service we need to keep processing the
|
||||
# frame. If we switched, we just swallow the frame.
|
||||
if not service:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we switched to a new service, request its metadata.
|
||||
if service:
|
||||
await service.queue_frame(RequestMetadataFrame())
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -10,11 +10,13 @@ This module provides a processor that filters frames based on a custom function,
|
||||
allowing for flexible frame filtering logic in processing pipelines.
|
||||
"""
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
FilterType = Callable[[Frame], Awaitable[bool]]
|
||||
|
||||
|
||||
class FunctionFilter(FrameProcessor):
|
||||
"""A frame processor that filters frames using a custom function.
|
||||
@@ -26,9 +28,10 @@ class FunctionFilter(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Callable[[Frame], Awaitable[bool]],
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
filter: FilterType,
|
||||
direction: Optional[FrameDirection] = FrameDirection.DOWNSTREAM,
|
||||
filter_system_frames: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the function filter.
|
||||
|
||||
@@ -36,10 +39,13 @@ class FunctionFilter(FrameProcessor):
|
||||
filter: An async function that takes a Frame and returns True if the
|
||||
frame should pass through, False otherwise.
|
||||
direction: The direction to apply filtering. Only frames moving in
|
||||
this direction will be filtered. Defaults to DOWNSTREAM.
|
||||
this direction will be filtered; frames in the other direction
|
||||
pass through unfiltered. If None, frames in both directions
|
||||
are filtered. Defaults to DOWNSTREAM.
|
||||
filter_system_frames: Whether to filter system frames. Defaults to False.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._filter = filter
|
||||
self._direction = direction
|
||||
self._filter_system_frames = filter_system_frames
|
||||
@@ -51,7 +57,7 @@ class FunctionFilter(FrameProcessor):
|
||||
def _should_passthrough_frame(self, frame, direction):
|
||||
"""Check if a frame should pass through without filtering."""
|
||||
# Always passthrough frames in the wrong direction
|
||||
if direction != self._direction:
|
||||
if self._direction and direction != self._direction:
|
||||
return True
|
||||
|
||||
# Always passthrough lifecycle frames
|
||||
|
||||
@@ -14,10 +14,12 @@ from pipecat.frames.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.filters.frame_filter import FrameFilter
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.filters.wake_check_filter import WakeCheckFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
@@ -93,6 +95,98 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_no_direction_filters_both_directions(self):
|
||||
"""When direction is None, frames in both directions are filtered."""
|
||||
|
||||
class UpstreamPusher(FrameProcessor):
|
||||
"""Pushes a TextFrame upstream when it receives a system frame."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
|
||||
|
||||
async def block_text(frame: Frame):
|
||||
return not isinstance(frame, TextFrame)
|
||||
|
||||
# direction=None: filter applies in both directions. The downstream
|
||||
# TextFrame is blocked and the upstream TextFrame pushed by
|
||||
# UpstreamPusher is also blocked.
|
||||
filter = FunctionFilter(filter=block_text, direction=None)
|
||||
pipeline = Pipeline([filter, UpstreamPusher()])
|
||||
frames_to_send = [
|
||||
TextFrame(text="Hello!"),
|
||||
UserStartedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [UserStartedSpeakingFrame]
|
||||
expected_up_frames = []
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
async def test_downstream_direction_passes_upstream(self):
|
||||
"""When direction is DOWNSTREAM, upstream frames pass through unfiltered."""
|
||||
|
||||
class UpstreamPusher(FrameProcessor):
|
||||
"""Pushes a TextFrame upstream when it receives a system frame."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
|
||||
|
||||
async def block_text(frame: Frame):
|
||||
return not isinstance(frame, TextFrame)
|
||||
|
||||
# direction=DOWNSTREAM: filter only applies downstream, so the
|
||||
# upstream TextFrame pushed by UpstreamPusher passes through.
|
||||
filter = FunctionFilter(filter=block_text)
|
||||
pipeline = Pipeline([filter, UpstreamPusher()])
|
||||
frames_to_send = [UserStartedSpeakingFrame()]
|
||||
expected_down_frames = [UserStartedSpeakingFrame]
|
||||
expected_up_frames = [TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
async def test_upstream_direction_passes_downstream(self):
|
||||
"""When direction is UPSTREAM, downstream frames pass through unfiltered."""
|
||||
|
||||
class UpstreamPusher(FrameProcessor):
|
||||
"""Pushes a TextFrame upstream when it receives a system frame."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
|
||||
|
||||
async def block_text(frame: Frame):
|
||||
return not isinstance(frame, TextFrame)
|
||||
|
||||
# direction=UPSTREAM: filter only applies upstream, so the
|
||||
# downstream TextFrame passes through but the upstream TextFrame
|
||||
# pushed by UpstreamPusher is blocked.
|
||||
filter = FunctionFilter(filter=block_text, direction=FrameDirection.UPSTREAM)
|
||||
pipeline = Pipeline([filter, UpstreamPusher()])
|
||||
frames_to_send = [TextFrame(text="Hello!"), UserStartedSpeakingFrame()]
|
||||
expected_down_frames = [UserStartedSpeakingFrame, TextFrame]
|
||||
expected_up_frames = []
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
|
||||
class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_no_wake_word(self):
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
"""Unit tests for ServiceSwitcher and related components."""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -122,14 +123,7 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(strategy.services, self.services)
|
||||
self.assertEqual(strategy.active_service, self.service1) # First service should be active
|
||||
|
||||
def test_init_with_empty_services(self):
|
||||
"""Test initialization with an empty list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual([])
|
||||
|
||||
self.assertEqual(strategy.services, [])
|
||||
self.assertIsNone(strategy.active_service)
|
||||
|
||||
def test_handle_manually_switch_service_frame(self):
|
||||
async def test_handle_manually_switch_service_frame(self):
|
||||
"""Test manual service switching with ManuallySwitchServiceFrame."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
@@ -139,7 +133,7 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
@@ -147,21 +141,66 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
def test_handle_frame_unsupported_frame_type(self):
|
||||
async def test_on_service_switched_event(self):
|
||||
"""Test that on_service_switched event fires with correct arguments."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
switched_events = []
|
||||
|
||||
@strategy.event_handler("on_service_switched")
|
||||
async def on_service_switched(strategy, service):
|
||||
switched_events.append((strategy, service))
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0) # Let async event task run
|
||||
|
||||
self.assertEqual(len(switched_events), 1)
|
||||
self.assertIsInstance(switched_events[0][0], ServiceSwitcherStrategyManual)
|
||||
self.assertEqual(switched_events[0][1], self.service2)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 2)
|
||||
self.assertEqual(switched_events[1][1], self.service3)
|
||||
|
||||
async def test_on_service_switched_event_not_fired_for_unknown_service(self):
|
||||
"""Test that on_service_switched event does not fire for services not in the list."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
switched_events = []
|
||||
|
||||
@strategy.event_handler("on_service_switched")
|
||||
async def on_service_switched(strategy, service):
|
||||
switched_events.append(service)
|
||||
|
||||
# Try switching to a service not in the list
|
||||
unknown_service = MockFrameProcessor("unknown")
|
||||
switch_frame = ManuallySwitchServiceFrame(service=unknown_service)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 0)
|
||||
self.assertEqual(strategy.active_service, self.service1) # Unchanged
|
||||
|
||||
async def test_handle_frame_unsupported_frame_type(self):
|
||||
"""Test that unsupported frame types raise an error."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
result = await strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIn("Unsupported frame type", str(context.exception))
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -267,7 +306,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
ManuallySwitchServiceFrame(service=self.service2),
|
||||
TextFrame("Hello 2"),
|
||||
],
|
||||
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
|
||||
expected_down_frames=[TextFrame, TextFrame],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
@@ -333,9 +372,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
],
|
||||
expected_down_frames=[
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
@@ -429,9 +466,8 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=[
|
||||
MockMetadataFrame, # From startup (service1)
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
MockMetadataFrame, # From service2 after switch
|
||||
TextFrame,
|
||||
],
|
||||
expected_up_frames=[],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user