Simplify ServiceSwitcher with closure-based filters

- Make ServiceSwitcherStrategy inherit from BaseObject with properties
  for services and active_service, and move initial service selection
  into the base class
- Add on_service_switched event to ServiceSwitcherStrategy
- handle_frame now returns the switched-to service (or None), allowing
  ServiceSwitcher to swallow ManuallySwitchServiceFrame on switch and
  request metadata from the new active service
- Override push_frame to suppress RequestMetadataFrame and
  ServiceMetadataFrame from inactive services
- Remove ServiceSwitcherFilter and ServiceSwitcherFilterFrame in favor
  of plain FunctionFilter instances with closures that check the
  strategy's active service directly
- FunctionFilter: add FilterType alias
- FunctionFilter: when direction is None, frames in both directions
  are filtered instead of just one
- Add docstrings to ServiceSwitcher and its components
This commit is contained in:
Aleix Conchillo Flaqué
2026-02-06 19:18:18 -08:00
committed by Mark Backman
parent 5e66702cf5
commit 2a572aedba
5 changed files with 297 additions and 161 deletions

View File

@@ -44,7 +44,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
return self.services
@property
def active_llm(self) -> Optional[LLMService]:
def active_llm(self) -> LLMService:
"""Get the currently active LLM.
Returns:

View File

@@ -6,11 +6,10 @@
"""Service switcher for switching between different services at runtime, with different switching strategies."""
from dataclasses import dataclass
from abc import abstractmethod
from typing import Any, Generic, List, Optional, Type, TypeVar
from pipecat.frames.frames import (
ControlFrame,
Frame,
ManuallySwitchServiceFrame,
RequestMetadataFrame,
@@ -20,14 +19,25 @@ from pipecat.frames.frames import (
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.base_object import BaseObject
class ServiceSwitcherStrategy:
class ServiceSwitcherStrategy(BaseObject):
"""Base class for service switching strategies.
Note:
Strategy classes are instantiated internally by ServiceSwitcher.
Developers should pass the strategy class (not an instance) to ServiceSwitcher.
Event handlers available:
- on_service_switched: Called when the active service changes.
Example::
@strategy.event_handler("on_service_switched")
async def on_service_switched(strategy, service):
...
"""
def __init__(self, services: List[FrameProcessor]):
@@ -39,20 +49,42 @@ class ServiceSwitcherStrategy:
Args:
services: List of frame processors to switch between.
"""
self.services = services
self.active_service: Optional[FrameProcessor] = None
super().__init__()
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
if len(services) == 0:
raise Exception(f"ServiceSwitcherStrategy needs at least one service")
self._services = services
self._active_service = services[0]
self._register_event_handler("on_service_switched")
@property
def services(self) -> List[FrameProcessor]:
"""Return the list of available services."""
return self._services
@property
def active_service(self) -> FrameProcessor:
"""Return the currently active service."""
return self._active_service
@abstractmethod
async def handle_frame(
self, frame: ServiceSwitcherFrame, direction: FrameDirection
) -> Optional[FrameProcessor]:
"""Handle a frame that controls service switching.
This method can be overridden by subclasses to implement specific logic
for handling frames that control service switching.
Subclasses implement this to decide whether a switch should occur.
Args:
frame: The frame to handle.
direction: The direction of the frame (upstream or downstream).
Returns:
The newly active service if a switch occurred, or None otherwise.
"""
raise NotImplementedError("Subclasses must implement this method.")
pass
class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
@@ -69,31 +101,24 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
)
"""
def __init__(self, services: List[FrameProcessor]):
"""Initialize the manual service switcher strategy with a list of services.
Note:
This is called internally by ServiceSwitcher. Do not instantiate directly.
Args:
services: List of frame processors to switch between.
"""
super().__init__(services)
self.active_service = services[0] if services else None
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
async def handle_frame(
self, frame: ServiceSwitcherFrame, direction: FrameDirection
) -> Optional[FrameProcessor]:
"""Handle a frame that controls service switching.
Args:
frame: The frame to handle.
direction: The direction of the frame (upstream or downstream).
Returns:
The newly active service if a switch occurred, or None otherwise.
"""
if isinstance(frame, ManuallySwitchServiceFrame):
self._set_active_if_available(frame.service)
else:
raise ValueError(f"Unsupported frame type: {type(frame)}")
return await self._set_active_if_available(frame.service)
def _set_active_if_available(self, service: FrameProcessor):
return None
async def _set_active_if_available(self, service: FrameProcessor) -> Optional[FrameProcessor]:
"""Set the active service to the given one, if it is in the list of available services.
If it's not in the list, the request is ignored, as it may have been
@@ -101,16 +126,35 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
Args:
service: The service to set as active.
Returns:
The newly active service, or None if the service was not found.
"""
if service in self.services:
self.active_service = service
self._active_service = service
await self._call_event_handler("on_service_switched", service)
return service
return None
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
"""A pipeline that switches between different services at runtime."""
"""Parallel pipeline that routes frames to one active service at a time.
Wraps each service in a pair of filters that gate frame flow based on
which service is currently active. Switching is controlled by
`ServiceSwitcherFrame` frames and delegated to a pluggable
`ServiceSwitcherStrategy`.
Example::
switcher = ServiceSwitcher(
services=[stt_1, stt_2],
strategy_type=ServiceSwitcherStrategyManual,
)
"""
def __init__(self, services: List[FrameProcessor], strategy_type: Type[StrategyType]):
"""Initialize the service switcher with a list of services and a switching strategy.
@@ -119,101 +163,20 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
services: List of frame processors to switch between.
strategy_type: The strategy class to use for switching between services.
"""
strategy = strategy_type(services)
super().__init__(*self._make_pipeline_definitions(services, strategy))
self.services = services
self.strategy = strategy
_strategy = strategy_type(services)
super().__init__(*self._make_pipeline_definitions(services, _strategy))
self._services = services
self._strategy = _strategy
class ServiceSwitcherFilter(FunctionFilter):
"""An internal filter that gates frame flow based on active service.
@property
def strategy(self) -> StrategyType:
"""Return the active switching strategy."""
return self._strategy
Two filters "sandwich" each service, allowing frames through only
when the wrapped service is active. The pipeline layout is::
DownstreamFilter → Service → UpstreamFilter
The filter names refer to which *direction* of frame flow they
filter, not their physical position: the downstream filter sits
*before* the service (filtering frames flowing into it) and the
upstream filter sits *after* it (filtering frames flowing back out).
"""
def __init__(
self,
wrapped_service: FrameProcessor,
active_service: FrameProcessor,
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction.
Args:
wrapped_service: The service that this filter wraps.
active_service: The currently active service.
direction: The direction of frame flow this filter gates
(DOWNSTREAM for the filter before the service,
UPSTREAM for the filter after it).
"""
self._wrapped_service = wrapped_service
self._active_service = active_service
async def filter(_: Frame) -> bool:
return self._wrapped_service == self._active_service
super().__init__(filter, direction, filter_system_frames=True)
async def process_frame(self, frame, direction):
"""Process a frame through the filter, handling special internal filter-updating frames."""
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
old_active = self._active_service
self._active_service = frame.active_service
# Two ServiceSwitcherFilters "sandwich" a service. The
# frame enters via the downstream filter first. Push it
# through so the upstream filter also updates its state.
if direction == self._direction:
await self.push_frame(frame, direction)
# This is the upstream filter (the second to update). At
# this point both filters know the new active service, so
# it's safe to request metadata — the resulting
# ServiceMetadataFrame will pass both filters on its way
# out. Only do this for the newly active service's sandwich.
elif (
self._direction == FrameDirection.UPSTREAM
and self._wrapped_service == frame.active_service
and old_active != self._wrapped_service
):
await self.push_frame(RequestMetadataFrame(), FrameDirection.UPSTREAM)
return
# RequestMetadataFrame is pushed upstream by the upstream filter
# (above) and consumed by the service. Guard against services
# that don't consume it: only forward in the filter's own
# direction (so it can reach the service) and only for the
# active service. Block in all other cases to prevent it from
# escaping the sandwich.
if isinstance(frame, RequestMetadataFrame):
if direction == self._direction and self._wrapped_service == self._active_service:
await self.push_frame(frame, direction)
return
# Block ServiceMetadataFrame from inactive services.
if isinstance(frame, ServiceMetadataFrame):
if self._wrapped_service != self._active_service:
return
await self.push_frame(frame, direction)
return
await super().process_frame(frame, direction)
@dataclass
class ServiceSwitcherFilterFrame(ControlFrame):
"""An internal frame used to update filter state on service switch.
Sent when a service switch occurs to update the active service in
the sandwich filters and trigger metadata emission from the newly
active service.
"""
active_service: FrameProcessor
@property
def services(self) -> List[FrameProcessor]:
"""Return the list of available services."""
return self._services
@staticmethod
def _make_pipeline_definitions(
@@ -228,21 +191,53 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
def _make_pipeline_definition(
service: FrameProcessor, strategy: ServiceSwitcherStrategy
) -> Any:
# Layout: DownstreamFilter → Service → UpstreamFilter
async def filter(_: Frame) -> bool:
return service == strategy.active_service
# Layout: Filter → Service → Filter
#
# filter_system_frames: we want to run filter functions also on system
# frames.
#
# enable_direct_mode: filter functions are quick so we don't need
# additional tasks.
return [
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
FunctionFilter(
filter=filter,
direction=FrameDirection.DOWNSTREAM,
filter_system_frames=True,
enable_direct_mode=True,
),
service,
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
FunctionFilter(
filter=filter,
direction=FrameDirection.UPSTREAM,
filter_system_frames=True,
enable_direct_mode=True,
),
]
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame out of the service switcher.
Suppresses `RequestMetadataFrame` (internal to the switcher) and
`ServiceMetadataFrame` from inactive services so only the active
service's metadata reaches downstream processors. One case this happens
is with `StartFrame` since all the filters let it pass, and `StartFrame`
causes the service to generate `ServiceMetadataFrame`.
"""
# Don't let RequestMetadataFrame out.
if isinstance(frame, RequestMetadataFrame):
return
# Only let metadata from the active service escape.
if isinstance(frame, ServiceMetadataFrame):
if frame.service_name != self.strategy.active_service.name:
return
await super().push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame, handling frames which affect service switching.
@@ -250,11 +245,16 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
frame: The frame to process.
direction: The direction of the frame (upstream or downstream).
"""
await super().process_frame(frame, direction)
if isinstance(frame, ServiceSwitcherFrame):
self.strategy.handle_frame(frame, direction)
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
active_service=self.strategy.active_service
)
await super().process_frame(service_switcher_filter_frame, direction)
service = await self.strategy.handle_frame(frame, direction)
# If we don't switch to a new service we need to keep processing the
# frame. If we switched, we just swallow the frame.
if not service:
await super().process_frame(frame, direction)
# If we switched to a new service, request its metadata.
if service:
await service.queue_frame(RequestMetadataFrame())
else:
await super().process_frame(frame, direction)

View File

@@ -10,11 +10,13 @@ This module provides a processor that filters frames based on a custom function,
allowing for flexible frame filtering logic in processing pipelines.
"""
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Optional
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
FilterType = Callable[[Frame], Awaitable[bool]]
class FunctionFilter(FrameProcessor):
"""A frame processor that filters frames using a custom function.
@@ -26,9 +28,10 @@ class FunctionFilter(FrameProcessor):
def __init__(
self,
filter: Callable[[Frame], Awaitable[bool]],
direction: FrameDirection = FrameDirection.DOWNSTREAM,
filter: FilterType,
direction: Optional[FrameDirection] = FrameDirection.DOWNSTREAM,
filter_system_frames: bool = False,
**kwargs,
):
"""Initialize the function filter.
@@ -36,10 +39,13 @@ class FunctionFilter(FrameProcessor):
filter: An async function that takes a Frame and returns True if the
frame should pass through, False otherwise.
direction: The direction to apply filtering. Only frames moving in
this direction will be filtered. Defaults to DOWNSTREAM.
this direction will be filtered; frames in the other direction
pass through unfiltered. If None, frames in both directions
are filtered. Defaults to DOWNSTREAM.
filter_system_frames: Whether to filter system frames. Defaults to False.
**kwargs: Additional arguments passed to parent class.
"""
super().__init__()
super().__init__(**kwargs)
self._filter = filter
self._direction = direction
self._filter_system_frames = filter_system_frames
@@ -51,7 +57,7 @@ class FunctionFilter(FrameProcessor):
def _should_passthrough_frame(self, frame, direction):
"""Check if a frame should pass through without filtering."""
# Always passthrough frames in the wrong direction
if direction != self._direction:
if self._direction and direction != self._direction:
return True
# Always passthrough lifecycle frames

View File

@@ -14,10 +14,12 @@ from pipecat.frames.frames import (
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.filters.frame_filter import FrameFilter
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.filters.identity_filter import IdentityFilter
from pipecat.processors.filters.wake_check_filter import WakeCheckFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import run_test
@@ -93,6 +95,98 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
expected_down_frames=expected_down_frames,
)
async def test_no_direction_filters_both_directions(self):
"""When direction is None, frames in both directions are filtered."""
class UpstreamPusher(FrameProcessor):
"""Pushes a TextFrame upstream when it receives a system frame."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
async def block_text(frame: Frame):
return not isinstance(frame, TextFrame)
# direction=None: filter applies in both directions. The downstream
# TextFrame is blocked and the upstream TextFrame pushed by
# UpstreamPusher is also blocked.
filter = FunctionFilter(filter=block_text, direction=None)
pipeline = Pipeline([filter, UpstreamPusher()])
frames_to_send = [
TextFrame(text="Hello!"),
UserStartedSpeakingFrame(),
]
expected_down_frames = [UserStartedSpeakingFrame]
expected_up_frames = []
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
async def test_downstream_direction_passes_upstream(self):
"""When direction is DOWNSTREAM, upstream frames pass through unfiltered."""
class UpstreamPusher(FrameProcessor):
"""Pushes a TextFrame upstream when it receives a system frame."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
async def block_text(frame: Frame):
return not isinstance(frame, TextFrame)
# direction=DOWNSTREAM: filter only applies downstream, so the
# upstream TextFrame pushed by UpstreamPusher passes through.
filter = FunctionFilter(filter=block_text)
pipeline = Pipeline([filter, UpstreamPusher()])
frames_to_send = [UserStartedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame]
expected_up_frames = [TextFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
async def test_upstream_direction_passes_downstream(self):
"""When direction is UPSTREAM, downstream frames pass through unfiltered."""
class UpstreamPusher(FrameProcessor):
"""Pushes a TextFrame upstream when it receives a system frame."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
await self.push_frame(TextFrame(text="upstream"), FrameDirection.UPSTREAM)
async def block_text(frame: Frame):
return not isinstance(frame, TextFrame)
# direction=UPSTREAM: filter only applies upstream, so the
# downstream TextFrame passes through but the upstream TextFrame
# pushed by UpstreamPusher is blocked.
filter = FunctionFilter(filter=block_text, direction=FrameDirection.UPSTREAM)
pipeline = Pipeline([filter, UpstreamPusher()])
frames_to_send = [TextFrame(text="Hello!"), UserStartedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame, TextFrame]
expected_up_frames = []
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
async def test_no_wake_word(self):

View File

@@ -6,6 +6,7 @@
"""Unit tests for ServiceSwitcher and related components."""
import asyncio
import unittest
from dataclasses import dataclass
@@ -122,14 +123,7 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
self.assertEqual(strategy.services, self.services)
self.assertEqual(strategy.active_service, self.service1) # First service should be active
def test_init_with_empty_services(self):
"""Test initialization with an empty list of services."""
strategy = ServiceSwitcherStrategyManual([])
self.assertEqual(strategy.services, [])
self.assertIsNone(strategy.active_service)
def test_handle_manually_switch_service_frame(self):
async def test_handle_manually_switch_service_frame(self):
"""Test manual service switching with ManuallySwitchServiceFrame."""
strategy = ServiceSwitcherStrategyManual(self.services)
@@ -139,7 +133,7 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
# Switch to service2
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertEqual(strategy.active_service, self.service2)
@@ -147,21 +141,66 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
# Switch to service3
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
self.assertEqual(strategy.active_service, self.service3)
def test_handle_frame_unsupported_frame_type(self):
async def test_on_service_switched_event(self):
"""Test that on_service_switched event fires with correct arguments."""
strategy = ServiceSwitcherStrategyManual(self.services)
switched_events = []
@strategy.event_handler("on_service_switched")
async def on_service_switched(strategy, service):
switched_events.append((strategy, service))
# Switch to service2
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
await asyncio.sleep(0) # Let async event task run
self.assertEqual(len(switched_events), 1)
self.assertIsInstance(switched_events[0][0], ServiceSwitcherStrategyManual)
self.assertEqual(switched_events[0][1], self.service2)
# Switch to service3
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
await asyncio.sleep(0)
self.assertEqual(len(switched_events), 2)
self.assertEqual(switched_events[1][1], self.service3)
async def test_on_service_switched_event_not_fired_for_unknown_service(self):
"""Test that on_service_switched event does not fire for services not in the list."""
strategy = ServiceSwitcherStrategyManual(self.services)
switched_events = []
@strategy.event_handler("on_service_switched")
async def on_service_switched(strategy, service):
switched_events.append(service)
# Try switching to a service not in the list
unknown_service = MockFrameProcessor("unknown")
switch_frame = ManuallySwitchServiceFrame(service=unknown_service)
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
await asyncio.sleep(0)
self.assertEqual(len(switched_events), 0)
self.assertEqual(strategy.active_service, self.service1) # Unchanged
async def test_handle_frame_unsupported_frame_type(self):
"""Test that unsupported frame types raise an error."""
strategy = ServiceSwitcherStrategyManual(self.services)
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
with self.assertRaises(ValueError) as context:
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
result = await strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
self.assertIn("Unsupported frame type", str(context.exception))
self.assertIsNone(result)
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
@@ -267,7 +306,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
ManuallySwitchServiceFrame(service=self.service2),
TextFrame("Hello 2"),
],
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
expected_down_frames=[TextFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
@@ -333,9 +372,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
],
expected_down_frames=[
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
],
expected_up_frames=[], # Expect no error frames
@@ -429,9 +466,8 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
expected_down_frames=[
MockMetadataFrame, # From startup (service1)
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
MockMetadataFrame, # From service2 after switch
TextFrame,
],
expected_up_frames=[],
)