diff --git a/src/pipecat/pipeline/service_switcher.py b/src/pipecat/pipeline/service_switcher.py index 76b703681..91a665236 100644 --- a/src/pipecat/pipeline/service_switcher.py +++ b/src/pipecat/pipeline/service_switcher.py @@ -193,7 +193,8 @@ class ServiceSwitcherStrategyFailover(ServiceSwitcherStrategyManual): The newly active service if a switch occurred, or None if no other service is available. """ - logger.warning(f"Service {self._active_service.name} reported an error: {error.error}") + service_name = error.processor.name if error.processor else self._active_service.name + logger.warning(f"Service {service_name} reported an error: {error.error}") if len(self._services) <= 1: logger.error("No other service available to switch to") @@ -313,8 +314,12 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]): return # Let the strategy react to non-fatal errors from the active service. + # We check that the error originated from one of our managed services + # to avoid reacting to errors that are just propagating upstream + # through the pipeline from downstream processors. if isinstance(frame, ErrorFrame) and not frame.fatal: - await self.strategy.handle_error(frame) + if frame.processor and frame.processor in self._services: + await self.strategy.handle_error(frame) await super().push_frame(frame, direction) diff --git a/tests/test_service_switcher.py b/tests/test_service_switcher.py index 4cdf38f1a..0ccea6e20 100644 --- a/tests/test_service_switcher.py +++ b/tests/test_service_switcher.py @@ -105,6 +105,38 @@ class MockMetadataService(FrameProcessor): self.metadata_push_count = 0 +class ErrorInjectorProcessor(FrameProcessor): + """A downstream processor that pushes an ErrorFrame upstream on receiving a TextFrame. + + Simulates an error from a service outside the ServiceSwitcher (e.g. TTS + erroring while propagating upstream through an LLM switcher). + """ + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, TextFrame) and direction == FrameDirection.DOWNSTREAM: + await self.push_error("downstream service error") + await self.push_frame(frame, direction) + + +class ErrorOnTextService(FrameProcessor): + """A mock service that pushes an error on the first TextFrame it receives. + + Simulates a managed service inside a ServiceSwitcher that encounters an error. + """ + + def __init__(self, test_name: str, **kwargs): + super().__init__(name=test_name, **kwargs) + self._errored = False + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, TextFrame) and not self._errored: + self._errored = True + await self.push_error("service connection lost") + await self.push_frame(frame, direction) + + @dataclass class DummySystemFrame(SystemFrame): """A dummy system frame for testing purposes.""" @@ -580,6 +612,49 @@ class TestServiceSwitcherStrategyFailover(unittest.IsolatedAsyncioTestCase): self.assertEqual(result, self.service3) self.assertEqual(strategy.active_service, self.service3) + async def test_passthrough_error_does_not_trigger_failover(self): + """Test that an error propagating upstream from a downstream processor does not trigger failover. + + This reproduces the bug where an ErrorFrame from e.g. TTS propagates + upstream through an LLM ServiceSwitcher and incorrectly triggers + failover even though neither LLM service produced the error. + """ + switcher = ServiceSwitcher( + [self.service1, self.service2], + strategy_type=ServiceSwitcherStrategyFailover, + ) + error_injector = ErrorInjectorProcessor() + pipeline = Pipeline([switcher, error_injector]) + + await run_test( + pipeline, + frames_to_send=[TextFrame(text="test")], + expected_down_frames=[TextFrame], + expected_up_frames=[ErrorFrame], + ) + + # Active service should NOT have changed — the error came from outside + self.assertEqual(switcher.strategy.active_service, self.service1) + + async def test_managed_service_error_triggers_failover(self): + """Test that an error from a managed service inside the switcher triggers failover.""" + error_service = ErrorOnTextService("error_service") + backup_service = MockFrameProcessor("backup_service") + switcher = ServiceSwitcher( + [error_service, backup_service], + strategy_type=ServiceSwitcherStrategyFailover, + ) + + await run_test( + switcher, + frames_to_send=[TextFrame(text="test")], + expected_down_frames=[TextFrame], + expected_up_frames=[ErrorFrame], + ) + + # Active service SHOULD have changed — the error came from a managed service + self.assertEqual(switcher.strategy.active_service, backup_service) + async def test_on_service_switched_event_fires_on_error(self): """Test that on_service_switched event fires when an error triggers a switch.""" strategy = ServiceSwitcherStrategyFailover(self.services)