Add ServiceSwitcherStrategyFailover for automatic failover on service errors (#3870)
* Add ServiceSwitcherStrategyFailover for automatic error-based service switching Introduce a strategy hierarchy: ServiceSwitcherStrategy (base) → ServiceSwitcherStrategyManual (handles ManuallySwitchServiceFrame) → ServiceSwitcherStrategyFailover (adds error-based failover). ServiceSwitcher now defaults to ServiceSwitcherStrategyManual with strategy_type optional. Non-fatal ErrorFrames are forwarded to the strategy via handle_error(). * Move metadata request into _set_active_if_available Requesting metadata is part of making a service active, so it belongs alongside setting _active_service and firing on_service_switched. This removes the duplicate queue_frame calls from ServiceSwitcher push_frame and process_frame.
This commit is contained in:
1
changelog/3861.added.md
Normal file
1
changelog/3861.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `ServiceSwitcherStrategyFailover` that automatically switches to the next service when the active service reports a non-fatal error. Recovery policies can be implemented via the `on_service_switched` event handler.
|
||||
1
changelog/3861.changed.md
Normal file
1
changelog/3861.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- `ServiceSwitcherStrategy` base class now provides a `handle_error()` hook for subclasses to implement error-based switching. `ServiceSwitcher` defaults to `ServiceSwitcherStrategyManual` and `strategy_type` is now optional.
|
||||
@@ -17,7 +17,7 @@ from pipecat.frames.frames import LLMRunFrame, ManuallySwitchServiceFrame
|
||||
from pipecat.pipeline.llm_switcher import LLMSwitcher
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
@@ -96,9 +96,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
stt_cartesia = CartesiaSTTService(api_key=os.getenv("CARTESIA_API_KEY"))
|
||||
stt_deepgram = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt_switcher = ServiceSwitcher(
|
||||
services=[stt_cartesia, stt_deepgram], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
# Uses ServiceSwitcherStrategyManual by default
|
||||
stt_switcher = ServiceSwitcher(services=[stt_cartesia, stt_deepgram])
|
||||
|
||||
tts_cartesia = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
@@ -112,9 +111,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice="aura-2-helena-en",
|
||||
),
|
||||
)
|
||||
tts_switcher = ServiceSwitcher(
|
||||
services=[tts_cartesia, tts_deepgram], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
# Uses ServiceSwitcherStrategyManual by default
|
||||
tts_switcher = ServiceSwitcher(services=[tts_cartesia, tts_deepgram])
|
||||
|
||||
system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way."
|
||||
|
||||
@@ -126,9 +124,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
settings=GoogleLLMService.Settings(system_instruction=system_prompt),
|
||||
)
|
||||
llm_switcher = LLMSwitcher(
|
||||
llms=[llm_openai, llm_google], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
# Uses ServiceSwitcherStrategyManual by default
|
||||
llm_switcher = LLMSwitcher(llms=[llm_openai, llm_google])
|
||||
# Register a "classic" function
|
||||
llm_switcher.register_function("get_current_weather", fetch_weather_from_api)
|
||||
# Register a "direct" function
|
||||
|
||||
@@ -9,7 +9,11 @@
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunction
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, StrategyType
|
||||
from pipecat.pipeline.service_switcher import (
|
||||
ServiceSwitcher,
|
||||
ServiceSwitcherStrategyManual,
|
||||
StrategyType,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
@@ -19,18 +23,20 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
|
||||
Example::
|
||||
|
||||
llm_switcher = LLMSwitcher(
|
||||
llms=[openai_llm, anthropic_llm],
|
||||
strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
llm_switcher = LLMSwitcher(llms=[openai_llm, anthropic_llm])
|
||||
"""
|
||||
|
||||
def __init__(self, llms: List[LLMService], strategy_type: Type[StrategyType]):
|
||||
def __init__(
|
||||
self,
|
||||
llms: List[LLMService],
|
||||
strategy_type: Type[StrategyType] = ServiceSwitcherStrategyManual,
|
||||
):
|
||||
"""Initialize the service switcher with a list of LLMs and a switching strategy.
|
||||
|
||||
Args:
|
||||
llms: List of LLM services to switch between.
|
||||
strategy_type: The strategy class to use for switching between LLMs.
|
||||
Defaults to ``ServiceSwitcherStrategyManual``.
|
||||
"""
|
||||
super().__init__(llms, strategy_type)
|
||||
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
|
||||
"""Service switcher for switching between different services at runtime, with different switching strategies."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
ServiceMetadataFrame,
|
||||
@@ -69,13 +71,13 @@ class ServiceSwitcherStrategy(BaseObject):
|
||||
"""Return the currently active service."""
|
||||
return self._active_service
|
||||
|
||||
@abstractmethod
|
||||
async def handle_frame(
|
||||
self, frame: ServiceSwitcherFrame, direction: FrameDirection
|
||||
) -> Optional[FrameProcessor]:
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
Subclasses implement this to decide whether a switch should occur.
|
||||
The base implementation returns ``None`` for all frames. Subclasses
|
||||
override this to implement specific switching behaviors.
|
||||
|
||||
Args:
|
||||
frame: The frame to handle.
|
||||
@@ -84,7 +86,41 @@ class ServiceSwitcherStrategy(BaseObject):
|
||||
Returns:
|
||||
The newly active service if a switch occurred, or None otherwise.
|
||||
"""
|
||||
pass
|
||||
return None
|
||||
|
||||
async def handle_error(self, error: ErrorFrame) -> Optional[FrameProcessor]:
|
||||
"""Handle an error from the active service.
|
||||
|
||||
Called by ``ServiceSwitcher`` when a non-fatal ``ErrorFrame`` is pushed
|
||||
upstream by the currently active service. Subclasses can override this
|
||||
to implement automatic failover.
|
||||
|
||||
Args:
|
||||
error: The error frame pushed by the active service.
|
||||
|
||||
Returns:
|
||||
The newly active service if a switch occurred, or None otherwise.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def _set_active_if_available(self, service: FrameProcessor) -> Optional[FrameProcessor]:
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
intended for another ServiceSwitcher in the pipeline.
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
|
||||
Returns:
|
||||
The newly active service, or None if the service was not found.
|
||||
"""
|
||||
if service in self.services:
|
||||
self._active_service = service
|
||||
await service.queue_frame(ServiceSwitcherRequestMetadataFrame(service=service))
|
||||
await self._call_event_handler("on_service_switched", service)
|
||||
return service
|
||||
return None
|
||||
|
||||
|
||||
class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
@@ -118,23 +154,54 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
|
||||
return None
|
||||
|
||||
async def _set_active_if_available(self, service: FrameProcessor) -> Optional[FrameProcessor]:
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
intended for another ServiceSwitcher in the pipeline.
|
||||
class ServiceSwitcherStrategyFailover(ServiceSwitcherStrategyManual):
|
||||
"""A strategy that automatically switches to a backup service on failure.
|
||||
|
||||
When the active service produces a non-fatal error, this strategy switches
|
||||
to the next available service in the list. Recovery and fallback policies
|
||||
are left to application code via the ``on_service_switched`` event.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_service_switched: Called when the active service changes.
|
||||
|
||||
Example::
|
||||
|
||||
switcher = ServiceSwitcher(
|
||||
services=[primary_stt, backup_stt],
|
||||
strategy_type=ServiceSwitcherStrategyFailover,
|
||||
)
|
||||
|
||||
@switcher.strategy.event_handler("on_service_switched")
|
||||
async def on_switched(strategy, service):
|
||||
# App decides when/how to recover the failed service
|
||||
...
|
||||
"""
|
||||
|
||||
async def handle_error(self, error: ErrorFrame) -> Optional[FrameProcessor]:
|
||||
"""Handle an error from the active service by failing over.
|
||||
|
||||
Switches to the next service in the list. The failed service remains
|
||||
in the list and can be switched back to manually or via application
|
||||
logic in the ``on_service_switched`` event handler.
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
error: The error frame pushed by the active service.
|
||||
|
||||
Returns:
|
||||
The newly active service, or None if the service was not found.
|
||||
The newly active service if a switch occurred, or None if no
|
||||
other service is available.
|
||||
"""
|
||||
if service in self.services:
|
||||
self._active_service = service
|
||||
await self._call_event_handler("on_service_switched", service)
|
||||
return service
|
||||
return None
|
||||
logger.warning(f"Service {self._active_service.name} reported an error: {error.error}")
|
||||
|
||||
if len(self._services) <= 1:
|
||||
logger.error("No other service available to switch to")
|
||||
return None
|
||||
|
||||
current_idx = self._services.index(self._active_service)
|
||||
next_idx = (current_idx + 1) % len(self._services)
|
||||
return await self._set_active_if_available(self._services[next_idx])
|
||||
|
||||
|
||||
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
|
||||
@@ -150,18 +217,20 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
|
||||
Example::
|
||||
|
||||
switcher = ServiceSwitcher(
|
||||
services=[stt_1, stt_2],
|
||||
strategy_type=ServiceSwitcherStrategyManual,
|
||||
)
|
||||
switcher = ServiceSwitcher(services=[stt_1, stt_2])
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor], strategy_type: Type[StrategyType]):
|
||||
def __init__(
|
||||
self,
|
||||
services: List[FrameProcessor],
|
||||
strategy_type: Type[StrategyType] = ServiceSwitcherStrategyManual,
|
||||
):
|
||||
"""Initialize the service switcher with a list of services and a switching strategy.
|
||||
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
strategy_type: The strategy class to use for switching between services.
|
||||
Defaults to ``ServiceSwitcherStrategyManual``.
|
||||
"""
|
||||
_strategy = strategy_type(services)
|
||||
super().__init__(*self._make_pipeline_definitions(services, _strategy))
|
||||
@@ -227,6 +296,10 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
all the filters let it pass, and `StartFrame` causes the service to
|
||||
generate `ServiceMetadataFrame`.
|
||||
|
||||
Non-fatal ``ErrorFrame`` instances are forwarded to the strategy via
|
||||
``handle_error`` so strategies like ``ServiceSwitcherStrategyFailover``
|
||||
can perform failover. The error frame is still propagated upstream so
|
||||
that application-level error handlers can observe it.
|
||||
"""
|
||||
# Consume ServiceSwitcherRequestMetadataFrame once the targeted service
|
||||
# has handled it (i.e. the active service).
|
||||
@@ -239,6 +312,10 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
if frame.service_name != self.strategy.active_service.name:
|
||||
return
|
||||
|
||||
# Let the strategy react to non-fatal errors from the active service.
|
||||
if isinstance(frame, ErrorFrame) and not frame.fatal:
|
||||
await self.strategy.handle_error(frame)
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -255,9 +332,5 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
# frame. If we switched, we just swallow the frame.
|
||||
if not service:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we switched to a new service, request its metadata.
|
||||
if service:
|
||||
await service.queue_frame(ServiceSwitcherRequestMetadataFrame(service=service))
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -11,6 +11,7 @@ import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
ServiceMetadataFrame,
|
||||
@@ -20,7 +21,12 @@ from pipecat.frames.frames import (
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
|
||||
from pipecat.pipeline.service_switcher import (
|
||||
ServiceSwitcher,
|
||||
ServiceSwitcherStrategy,
|
||||
ServiceSwitcherStrategyFailover,
|
||||
ServiceSwitcherStrategyManual,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
@@ -106,8 +112,8 @@ class DummySystemFrame(SystemFrame):
|
||||
text: str = ""
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
class TestServiceSwitcherStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for the base ServiceSwitcherStrategy."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
@@ -118,10 +124,54 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def test_init_with_services(self):
|
||||
"""Test initialization with a list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
strategy = ServiceSwitcherStrategy(self.services)
|
||||
|
||||
self.assertEqual(strategy.services, self.services)
|
||||
self.assertEqual(strategy.active_service, self.service1) # First service should be active
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
async def test_handle_frame_returns_none_for_manual_switch(self):
|
||||
"""Test that base strategy does not handle ManuallySwitchServiceFrame."""
|
||||
strategy = ServiceSwitcherStrategy(self.services)
|
||||
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
result = await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
async def test_handle_frame_returns_none_for_unsupported_frame(self):
|
||||
"""Test that unsupported frame types return None."""
|
||||
strategy = ServiceSwitcherStrategy(self.services)
|
||||
unsupported_frame = TextFrame(text="test")
|
||||
|
||||
result = await strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
async def test_handle_error_returns_none(self):
|
||||
"""Test that handle_error returns None by default."""
|
||||
strategy = ServiceSwitcherStrategy(self.services)
|
||||
|
||||
result = await strategy.handle_error(ErrorFrame(error="error"))
|
||||
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_is_subclass_of_base_strategy(self):
|
||||
"""Test that ServiceSwitcherStrategyManual is a subclass of ServiceSwitcherStrategy."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
self.assertIsInstance(strategy, ServiceSwitcherStrategy)
|
||||
|
||||
async def test_handle_manually_switch_service_frame(self):
|
||||
"""Test manual service switching with ManuallySwitchServiceFrame."""
|
||||
@@ -129,22 +179,15 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Initially service1 should be active
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
self.assertNotEqual(strategy.active_service, self.service3)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
async def test_on_service_switched_event(self):
|
||||
@@ -157,25 +200,16 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
async def on_service_switched(strategy, service):
|
||||
switched_events.append((strategy, service))
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0) # Let async event task run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 1)
|
||||
self.assertIsInstance(switched_events[0][0], ServiceSwitcherStrategyManual)
|
||||
self.assertEqual(switched_events[0][1], self.service2)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 2)
|
||||
self.assertEqual(switched_events[1][1], self.service3)
|
||||
|
||||
async def test_on_service_switched_event_not_fired_for_unknown_service(self):
|
||||
"""Test that on_service_switched event does not fire for services not in the list."""
|
||||
async def test_unknown_service_ignored(self):
|
||||
"""Test that switching to an unknown service is ignored."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
switched_events = []
|
||||
@@ -184,23 +218,14 @@ class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
async def on_service_switched(strategy, service):
|
||||
switched_events.append(service)
|
||||
|
||||
# Try switching to a service not in the list
|
||||
unknown_service = MockFrameProcessor("unknown")
|
||||
switch_frame = ManuallySwitchServiceFrame(service=unknown_service)
|
||||
await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
result = await strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 0)
|
||||
self.assertEqual(strategy.active_service, self.service1) # Unchanged
|
||||
|
||||
async def test_handle_frame_unsupported_frame_type(self):
|
||||
"""Test that unsupported frame types raise an error."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
|
||||
|
||||
result = await strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(len(switched_events), 0)
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
|
||||
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -213,9 +238,9 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_with_manual_strategy(self):
|
||||
"""Test initialization with manual strategy."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
def test_init_with_default_strategy(self):
|
||||
"""Test initialization with default strategy."""
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
self.assertEqual(switcher.services, self.services)
|
||||
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
|
||||
@@ -223,7 +248,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_default_active_service(self):
|
||||
"""Test that the initially-active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
@@ -292,7 +317,7 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_service_switching(self):
|
||||
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
@@ -341,8 +366,8 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
switcher2_services = [switcher2_service1, switcher2_service2]
|
||||
|
||||
# Create two service switchers
|
||||
switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
|
||||
switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)
|
||||
switcher1 = ServiceSwitcher(switcher1_services)
|
||||
switcher2 = ServiceSwitcher(switcher2_services)
|
||||
|
||||
# Create a pipeline with both switchers: switcher1 -> switcher2
|
||||
pipeline = Pipeline([switcher1, switcher2])
|
||||
@@ -428,7 +453,7 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_only_active_service_metadata_at_startup(self):
|
||||
"""Test that only the active service's metadata leaves the ServiceSwitcher at startup."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
# Run the pipeline (StartFrame triggers metadata emission)
|
||||
output_frames = []
|
||||
@@ -450,7 +475,7 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_metadata_emitted_on_service_switch(self):
|
||||
"""Test that switching services triggers metadata emission from the new active service."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
# Reset counters after startup
|
||||
self.service1.reset_counters()
|
||||
@@ -482,7 +507,7 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_inactive_service_metadata_blocked(self):
|
||||
"""Test that metadata from inactive services is blocked."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
switcher = ServiceSwitcher(self.services)
|
||||
|
||||
# Run and collect output frames
|
||||
await run_test(
|
||||
@@ -497,5 +522,80 @@ class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
# Only one MockMetadataFrame should have left (from service1)
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyFailover(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyFailover."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_defaults(self):
|
||||
"""Test that default values are set correctly."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
async def test_error_switches_to_next_service(self):
|
||||
"""Test that an error on the active service switches to the next one."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
|
||||
error = ErrorFrame(error="connection lost")
|
||||
result = await strategy.handle_error(error)
|
||||
|
||||
self.assertEqual(result, self.service2)
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
|
||||
async def test_consecutive_errors_cycle_through_services(self):
|
||||
"""Test that repeated errors cycle through all services."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
|
||||
# First error: service1 -> service2
|
||||
await strategy.handle_error(ErrorFrame(error="error 1"))
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
|
||||
# Second error: service2 -> service3
|
||||
await strategy.handle_error(ErrorFrame(error="error 2"))
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
# Third error: service3 -> service1 (wraps around)
|
||||
await strategy.handle_error(ErrorFrame(error="error 3"))
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
|
||||
async def test_single_service_returns_none(self):
|
||||
"""Test that handle_error returns None with only one service."""
|
||||
strategy = ServiceSwitcherStrategyFailover([self.service1])
|
||||
|
||||
result = await strategy.handle_error(ErrorFrame(error="error"))
|
||||
self.assertIsNone(result)
|
||||
|
||||
async def test_manual_switch_still_works(self):
|
||||
"""Test that ManuallySwitchServiceFrame is still handled."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
|
||||
frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
result = await strategy.handle_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertEqual(result, self.service3)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
async def test_on_service_switched_event_fires_on_error(self):
|
||||
"""Test that on_service_switched event fires when an error triggers a switch."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
|
||||
switched_events = []
|
||||
|
||||
@strategy.event_handler("on_service_switched")
|
||||
async def on_service_switched(strategy, service):
|
||||
switched_events.append(service)
|
||||
|
||||
await strategy.handle_error(ErrorFrame(error="error"))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(len(switched_events), 1)
|
||||
self.assertEqual(switched_events[0], self.service2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user