Fix ServiceSwitcher reacting to pass-through ErrorFrames from other pipeline stages
ErrorFrames propagating upstream from downstream processors (e.g. TTS) would enter the ServiceSwitcher via process_frame, traverse the active service's sub-pipeline, and reach push_frame, where they incorrectly triggered failover. Now handle_error is only called for errors whose processor is one of the managed services. Also fix the log message in handle_error so it attributes an error to the processor that actually raised it rather than to the current active_service. Closes #4139
This commit is contained in:
@@ -193,7 +193,8 @@ class ServiceSwitcherStrategyFailover(ServiceSwitcherStrategyManual):
|
||||
The newly active service if a switch occurred, or None if no
|
||||
other service is available.
|
||||
"""
|
||||
logger.warning(f"Service {self._active_service.name} reported an error: {error.error}")
|
||||
service_name = error.processor.name if error.processor else self._active_service.name
|
||||
logger.warning(f"Service {service_name} reported an error: {error.error}")
|
||||
|
||||
if len(self._services) <= 1:
|
||||
logger.error("No other service available to switch to")
|
||||
@@ -313,8 +314,12 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
return
|
||||
|
||||
# Let the strategy react to non-fatal errors from the active service.
|
||||
# We check that the error originated from one of our managed services
|
||||
# to avoid reacting to errors that are just propagating upstream
|
||||
# through the pipeline from downstream processors.
|
||||
if isinstance(frame, ErrorFrame) and not frame.fatal:
|
||||
await self.strategy.handle_error(frame)
|
||||
if frame.processor and frame.processor in self._services:
|
||||
await self.strategy.handle_error(frame)
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -105,6 +105,38 @@ class MockMetadataService(FrameProcessor):
|
||||
self.metadata_push_count = 0
|
||||
|
||||
|
||||
class ErrorInjectorProcessor(FrameProcessor):
    """Pass-through processor that reports an error for each downstream TextFrame.

    Stands in for a service that lives outside the ServiceSwitcher (for
    example a TTS stage placed after an LLM switcher) whose ErrorFrame
    travels upstream through the switcher.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Report an error for downstream TextFrames, then forward the frame.

        Args:
            frame: The frame being processed.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        is_downstream_text = direction == FrameDirection.DOWNSTREAM and isinstance(
            frame, TextFrame
        )
        if is_downstream_text:
            await self.push_error("downstream service error")

        # Forwarding is unconditional so every frame keeps flowing through.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class ErrorOnTextService(FrameProcessor):
    """Mock managed service that reports an error once, on its first TextFrame.

    Used to simulate a service inside a ServiceSwitcher running into an error.
    """

    def __init__(self, test_name: str, **kwargs):
        """Create the mock service.

        Args:
            test_name: Name handed to the base processor as ``name``.
            **kwargs: Extra keyword arguments forwarded to the base processor.
        """
        super().__init__(name=test_name, **kwargs)
        # Flipped to True after the single error has been reported.
        self._errored = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Report an error on the first TextFrame seen, then forward the frame.

        Args:
            frame: The frame being processed.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        is_first_text_frame = not self._errored and isinstance(frame, TextFrame)
        if is_first_text_frame:
            # Set the flag before reporting so the error is emitted only once.
            self._errored = True
            await self.push_error("service connection lost")

        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySystemFrame(SystemFrame):
|
||||
"""A dummy system frame for testing purposes."""
|
||||
@@ -580,6 +612,49 @@ class TestServiceSwitcherStrategyFailover(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(result, self.service3)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
async def test_passthrough_error_does_not_trigger_failover(self):
    """Check that an error from outside the switcher leaves the active service alone.

    Regression test for the case where an ErrorFrame from a downstream
    processor (e.g. TTS) travels upstream through an LLM ServiceSwitcher
    and wrongly causes a failover, although neither managed service
    produced the error.
    """
    managed_services = [self.service1, self.service2]
    switcher = ServiceSwitcher(
        managed_services,
        strategy_type=ServiceSwitcherStrategyFailover,
    )
    # The injector sits after the switcher, so its error passes back through it.
    downstream_injector = ErrorInjectorProcessor()
    pipeline = Pipeline([switcher, downstream_injector])

    await run_test(
        pipeline,
        frames_to_send=[TextFrame(text="test")],
        expected_down_frames=[TextFrame],
        expected_up_frames=[ErrorFrame],
    )

    # The error came from outside the switcher, so the first service stays active.
    self.assertEqual(switcher.strategy.active_service, self.service1)
|
||||
|
||||
async def test_managed_service_error_triggers_failover(self):
    """Check that an error raised by a managed service causes a failover."""
    failing_service = ErrorOnTextService("error_service")
    standby_service = MockFrameProcessor("backup_service")
    switcher = ServiceSwitcher(
        [failing_service, standby_service],
        strategy_type=ServiceSwitcherStrategyFailover,
    )

    await run_test(
        switcher,
        frames_to_send=[TextFrame(text="test")],
        expected_down_frames=[TextFrame],
        expected_up_frames=[ErrorFrame],
    )

    # The error came from a managed service, so the standby must now be active.
    self.assertEqual(switcher.strategy.active_service, standby_service)
|
||||
|
||||
async def test_on_service_switched_event_fires_on_error(self):
|
||||
"""Test that on_service_switched event fires when an error triggers a switch."""
|
||||
strategy = ServiceSwitcherStrategyFailover(self.services)
|
||||
|
||||
Reference in New Issue
Block a user