Merge pull request #2907 from pipecat-ai/mb/service-switcher-updates

ServiceSwitcher updates
This commit is contained in:
kompfner
2025-10-22 11:23:48 -04:00
committed by GitHub
6 changed files with 287 additions and 20 deletions

View File

@@ -9,9 +9,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `FunctionFilter` now has a `filter_system_frames` arg, which controls whether
or not SystemFrames are filtered.
- Upgraded `aws_sdk_bedrock_runtime` to v0.1.1 to resolve potential CPU issues
when running `AWSNovaSonicLLMService`.
### Fixed
- Fixed an issue in `ServiceSwitcher` where the `STTService`s would result in
all STT services producing `TranscriptionFrame`s.
## [0.0.91] - 2025-10-21
### Added

View File

@@ -0,0 +1,153 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame, ManuallySwitchServiceFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.deepgram.tts import DeepgramTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt_cartesia = CartesiaSTTService(api_key=os.getenv("CARTESIA_API_KEY"))
stt_deepgram = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt_switcher = ServiceSwitcher(
services=[stt_cartesia, stt_deepgram], strategy_type=ServiceSwitcherStrategyManual
)
tts_cartesia = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",
)
tts_deepgram = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts_switcher = ServiceSwitcher(
services=[tts_cartesia, tts_deepgram], strategy_type=ServiceSwitcherStrategyManual
)
llm_openai = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
llm_google = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
llm_switcher = ServiceSwitcher(
services=[llm_openai, llm_google], strategy_type=ServiceSwitcherStrategyManual
)
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = LLMContext(messages)
context_aggregator = LLMContextAggregatorPair(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt_switcher,
context_aggregator.user(), # User responses
llm_switcher, # LLM
tts_switcher, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
await asyncio.sleep(15)
print(f"Switching to {stt_deepgram}")
await task.queue_frames([ManuallySwitchServiceFrame(service=stt_deepgram)])
await asyncio.sleep(15)
print(f"Switching to {llm_google}")
await task.queue_frames([ManuallySwitchServiceFrame(service=llm_google)])
await asyncio.sleep(15)
print(f"Switching to {tts_deepgram}")
await task.queue_frames([ManuallySwitchServiceFrame(service=tts_deepgram)])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -14,20 +14,41 @@ from pipecat.services.llm_service import LLMService
class LLMSwitcher(ServiceSwitcher[StrategyType]):
"""A pipeline that switches between different LLMs at runtime."""
"""A pipeline that switches between different LLMs at runtime.
Example::
llm_switcher = LLMSwitcher(
llms=[openai_llm, anthropic_llm],
strategy_type=ServiceSwitcherStrategyManual
)
"""
def __init__(self, llms: List[LLMService], strategy_type: Type[StrategyType]):
"""Initialize the service switcher with a list of LLMs and a switching strategy."""
"""Initialize the service switcher with a list of LLMs and a switching strategy.
Args:
llms: List of LLM services to switch between.
strategy_type: The strategy class to use for switching between LLMs.
"""
super().__init__(llms, strategy_type)
@property
def llms(self) -> List[LLMService]:
"""Get the list of LLMs managed by this switcher."""
"""Get the list of LLMs managed by this switcher.
Returns:
List of LLM services managed by this switcher.
"""
return self.services
@property
def active_llm(self) -> Optional[LLMService]:
"""Get the currently active LLM, if any."""
"""Get the currently active LLM.
Returns:
The currently active LLM service, or None if no LLM is active.
"""
return self.strategy.active_service
async def run_inference(self, context: LLMContext) -> Optional[str]:

View File

@@ -21,10 +21,22 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class ServiceSwitcherStrategy:
"""Base class for service switching strategies."""
"""Base class for service switching strategies.
Note:
Strategy classes are instantiated internally by ServiceSwitcher.
Developers should pass the strategy class (not an instance) to ServiceSwitcher.
"""
def __init__(self, services: List[FrameProcessor]):
"""Initialize the service switcher strategy with a list of services."""
"""Initialize the service switcher strategy with a list of services.
Note:
This is called internally by ServiceSwitcher. Do not instantiate directly.
Args:
services: List of frame processors to switch between.
"""
self.services = services
self.active_service: Optional[FrameProcessor] = None
@@ -46,10 +58,24 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
This strategy allows the user to manually select which service is active.
The initial active service is the first one in the list.
Example::
stt_switcher = ServiceSwitcher(
services=[stt_1, stt_2],
strategy_type=ServiceSwitcherStrategyManual
)
"""
def __init__(self, services: List[FrameProcessor]):
"""Initialize the manual service switcher strategy with a list of services."""
"""Initialize the manual service switcher strategy with a list of services.
Note:
This is called internally by ServiceSwitcher. Do not instantiate directly.
Args:
services: List of frame processors to switch between.
"""
super().__init__(services)
self.active_service = services[0] if services else None
@@ -85,7 +111,12 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
"""A pipeline that switches between different services at runtime."""
def __init__(self, services: List[FrameProcessor], strategy_type: Type[StrategyType]):
"""Initialize the service switcher with a list of services and a switching strategy."""
"""Initialize the service switcher with a list of services and a switching strategy.
Args:
services: List of frame processors to switch between.
strategy_type: The strategy class to use for switching between services.
"""
strategy = strategy_type(services)
super().__init__(*self._make_pipeline_definitions(services, strategy))
self.services = services
@@ -100,14 +131,20 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
active_service: FrameProcessor,
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction."""
"""Initialize the service switcher filter with a strategy and direction.
Args:
wrapped_service: The service that this filter wraps.
active_service: The currently active service.
direction: The direction of frame flow to filter.
"""
self._wrapped_service = wrapped_service
self._active_service = active_service
async def filter(_: Frame) -> bool:
return self._wrapped_service == self._active_service
super().__init__(filter, direction)
self._wrapped_service = wrapped_service
self._active_service = active_service
super().__init__(filter, direction, filter_system_frames=True)
async def process_frame(self, frame, direction):
"""Process a frame through the filter, handling special internal filter-updating frames."""

View File

@@ -12,7 +12,7 @@ allowing for flexible frame filtering logic in processing pipelines.
from typing import Awaitable, Callable
from pipecat.frames.frames import EndFrame, Frame, SystemFrame
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -28,6 +28,7 @@ class FunctionFilter(FrameProcessor):
self,
filter: Callable[[Frame], Awaitable[bool]],
direction: FrameDirection = FrameDirection.DOWNSTREAM,
filter_system_frames: bool = False,
):
"""Initialize the function filter.
@@ -36,22 +37,32 @@ class FunctionFilter(FrameProcessor):
frame should pass through, False otherwise.
direction: The direction to apply filtering. Only frames moving in
this direction will be filtered. Defaults to DOWNSTREAM.
filter_system_frames: Whether to filter system frames. Defaults to False.
"""
super().__init__()
self._filter = filter
self._direction = direction
self._filter_system_frames = filter_system_frames
#
# Frame processor
#
# Ignore system frames, end frames and frames that are not following the
# direction of this gate
def _should_passthrough_frame(self, frame, direction):
"""Check if a frame should pass through without filtering."""
# Ignore system frames, end frames and frames that are not following the
# direction of this gate
return isinstance(frame, (SystemFrame, EndFrame)) or direction != self._direction
# Always passthrough frames in the wrong direction
if direction != self._direction:
return True
# Always passthrough lifecycle frames
if isinstance(frame, (StartFrame, EndFrame, CancelFrame)):
return True
# If not filtering system frames, passthrough all other system frames
if not self._filter_system_frames and isinstance(frame, SystemFrame):
return True
return False
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame through the filter.

View File

@@ -7,10 +7,12 @@
"""Unit tests for ServiceSwitcher and related components."""
import unittest
from dataclasses import dataclass
from pipecat.frames.frames import (
Frame,
ManuallySwitchServiceFrame,
SystemFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
@@ -52,6 +54,13 @@ class MockFrameProcessor(FrameProcessor):
self.frame_count = 0
@dataclass
class DummySystemFrame(SystemFrame):
"""A dummy system frame for testing purposes."""
text: str = ""
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcherStrategyManual."""
@@ -140,14 +149,22 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
# Send some test frames
frames_to_send = [
TextFrame(text="Hello 1"),
DummySystemFrame(text="System Message 1"),
TextFrame(text="Hello 2"),
DummySystemFrame(text="System Message 2"),
TextFrame(text="Hello 3"),
]
await run_test(
switcher,
frames_to_send=frames_to_send,
expected_down_frames=[TextFrame, TextFrame, TextFrame],
expected_down_frames=[
DummySystemFrame,
DummySystemFrame,
TextFrame,
TextFrame,
TextFrame,
],
expected_up_frames=[], # Expect no error frames
)
@@ -156,7 +173,13 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
self.assertEqual(len(text_frames), 3)
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
# Only service1 should have processed the system frames
system_frames = [
f for f in self.service1.processed_frames if isinstance(f, DummySystemFrame)
]
self.assertEqual(len(system_frames), 2)
# Check that other services don't receive text frames (they still get StartFrame/EndFrame)
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
@@ -166,10 +189,24 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(service2_text_frames), 0)
self.assertEqual(len(service3_text_frames), 0)
# Check that other services don't receive dummy system frames (they still get StartFrame/EndFrame)
service2_system_frames = [
f for f in self.service2.processed_frames if isinstance(f, DummySystemFrame)
]
service3_system_frames = [
f for f in self.service3.processed_frames if isinstance(f, DummySystemFrame)
]
self.assertEqual(len(service2_system_frames), 0)
self.assertEqual(len(service3_system_frames), 0)
# Verify the actual text frames processed
for i, frame in enumerate(text_frames):
self.assertEqual(frame.text, f"Hello {i + 1}")
# Verify the actual system frames processed
for i, frame in enumerate(system_frames):
self.assertEqual(frame.text, f"System Message {i + 1}")
async def test_service_switching(self):
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)