Merge pull request #2907 from pipecat-ai/mb/service-switcher-updates
ServiceSwitcher updates
This commit is contained in:
@@ -9,9 +9,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `FunctionFilter` now has a `filter_system_frames` arg, which controls whether
|
||||
or not SystemFrames are filtered.
|
||||
|
||||
- Upgraded `aws_sdk_bedrock_runtime` to v0.1.1 to resolve potential CPU issues
|
||||
when running `AWSNovaSonicLLMService`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `ServiceSwitcher` where the `STTService`s would result in
|
||||
all STT services producing `TranscriptionFrame`s.
|
||||
|
||||
## [0.0.91] - 2025-10-21
|
||||
|
||||
### Added
|
||||
|
||||
153
examples/foundational/48-service-switcher.py
Normal file
153
examples/foundational/48-service-switcher.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, ManuallySwitchServiceFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt_cartesia = CartesiaSTTService(api_key=os.getenv("CARTESIA_API_KEY"))
|
||||
stt_deepgram = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt_switcher = ServiceSwitcher(
|
||||
services=[stt_cartesia, stt_deepgram], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
|
||||
tts_cartesia = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",
|
||||
)
|
||||
tts_deepgram = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
tts_switcher = ServiceSwitcher(
|
||||
services=[tts_cartesia, tts_deepgram], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
|
||||
llm_openai = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm_google = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
|
||||
llm_switcher = ServiceSwitcher(
|
||||
services=[llm_openai, llm_google], strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt_switcher,
|
||||
context_aggregator.user(), # User responses
|
||||
llm_switcher, # LLM
|
||||
tts_switcher, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
await asyncio.sleep(15)
|
||||
print(f"Switching to {stt_deepgram}")
|
||||
await task.queue_frames([ManuallySwitchServiceFrame(service=stt_deepgram)])
|
||||
await asyncio.sleep(15)
|
||||
print(f"Switching to {llm_google}")
|
||||
await task.queue_frames([ManuallySwitchServiceFrame(service=llm_google)])
|
||||
await asyncio.sleep(15)
|
||||
print(f"Switching to {tts_deepgram}")
|
||||
await task.queue_frames([ManuallySwitchServiceFrame(service=tts_deepgram)])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -14,20 +14,41 @@ from pipecat.services.llm_service import LLMService
|
||||
|
||||
|
||||
class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
"""A pipeline that switches between different LLMs at runtime."""
|
||||
"""A pipeline that switches between different LLMs at runtime.
|
||||
|
||||
Example::
|
||||
|
||||
llm_switcher = LLMSwitcher(
|
||||
llms=[openai_llm, anthropic_llm],
|
||||
strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, llms: List[LLMService], strategy_type: Type[StrategyType]):
|
||||
"""Initialize the service switcher with a list of LLMs and a switching strategy."""
|
||||
"""Initialize the service switcher with a list of LLMs and a switching strategy.
|
||||
|
||||
Args:
|
||||
llms: List of LLM services to switch between.
|
||||
strategy_type: The strategy class to use for switching between LLMs.
|
||||
"""
|
||||
super().__init__(llms, strategy_type)
|
||||
|
||||
@property
|
||||
def llms(self) -> List[LLMService]:
|
||||
"""Get the list of LLMs managed by this switcher."""
|
||||
"""Get the list of LLMs managed by this switcher.
|
||||
|
||||
Returns:
|
||||
List of LLM services managed by this switcher.
|
||||
"""
|
||||
return self.services
|
||||
|
||||
@property
|
||||
def active_llm(self) -> Optional[LLMService]:
|
||||
"""Get the currently active LLM, if any."""
|
||||
"""Get the currently active LLM.
|
||||
|
||||
Returns:
|
||||
The currently active LLM service, or None if no LLM is active.
|
||||
"""
|
||||
return self.strategy.active_service
|
||||
|
||||
async def run_inference(self, context: LLMContext) -> Optional[str]:
|
||||
|
||||
@@ -21,10 +21,22 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class ServiceSwitcherStrategy:
|
||||
"""Base class for service switching strategies."""
|
||||
"""Base class for service switching strategies.
|
||||
|
||||
Note:
|
||||
Strategy classes are instantiated internally by ServiceSwitcher.
|
||||
Developers should pass the strategy class (not an instance) to ServiceSwitcher.
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor]):
|
||||
"""Initialize the service switcher strategy with a list of services."""
|
||||
"""Initialize the service switcher strategy with a list of services.
|
||||
|
||||
Note:
|
||||
This is called internally by ServiceSwitcher. Do not instantiate directly.
|
||||
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
"""
|
||||
self.services = services
|
||||
self.active_service: Optional[FrameProcessor] = None
|
||||
|
||||
@@ -46,10 +58,24 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
|
||||
This strategy allows the user to manually select which service is active.
|
||||
The initial active service is the first one in the list.
|
||||
|
||||
Example::
|
||||
|
||||
stt_switcher = ServiceSwitcher(
|
||||
services=[stt_1, stt_2],
|
||||
strategy_type=ServiceSwitcherStrategyManual
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor]):
|
||||
"""Initialize the manual service switcher strategy with a list of services."""
|
||||
"""Initialize the manual service switcher strategy with a list of services.
|
||||
|
||||
Note:
|
||||
This is called internally by ServiceSwitcher. Do not instantiate directly.
|
||||
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
"""
|
||||
super().__init__(services)
|
||||
self.active_service = services[0] if services else None
|
||||
|
||||
@@ -85,7 +111,12 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
"""A pipeline that switches between different services at runtime."""
|
||||
|
||||
def __init__(self, services: List[FrameProcessor], strategy_type: Type[StrategyType]):
|
||||
"""Initialize the service switcher with a list of services and a switching strategy."""
|
||||
"""Initialize the service switcher with a list of services and a switching strategy.
|
||||
|
||||
Args:
|
||||
services: List of frame processors to switch between.
|
||||
strategy_type: The strategy class to use for switching between services.
|
||||
"""
|
||||
strategy = strategy_type(services)
|
||||
super().__init__(*self._make_pipeline_definitions(services, strategy))
|
||||
self.services = services
|
||||
@@ -100,14 +131,20 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
active_service: FrameProcessor,
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction."""
|
||||
"""Initialize the service switcher filter with a strategy and direction.
|
||||
|
||||
Args:
|
||||
wrapped_service: The service that this filter wraps.
|
||||
active_service: The currently active service.
|
||||
direction: The direction of frame flow to filter.
|
||||
"""
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
|
||||
async def filter(_: Frame) -> bool:
|
||||
return self._wrapped_service == self._active_service
|
||||
|
||||
super().__init__(filter, direction)
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
super().__init__(filter, direction, filter_system_frames=True)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
|
||||
@@ -12,7 +12,7 @@ allowing for flexible frame filtering logic in processing pipelines.
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from pipecat.frames.frames import EndFrame, Frame, SystemFrame
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class FunctionFilter(FrameProcessor):
|
||||
self,
|
||||
filter: Callable[[Frame], Awaitable[bool]],
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
filter_system_frames: bool = False,
|
||||
):
|
||||
"""Initialize the function filter.
|
||||
|
||||
@@ -36,22 +37,32 @@ class FunctionFilter(FrameProcessor):
|
||||
frame should pass through, False otherwise.
|
||||
direction: The direction to apply filtering. Only frames moving in
|
||||
this direction will be filtered. Defaults to DOWNSTREAM.
|
||||
filter_system_frames: Whether to filter system frames. Defaults to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self._filter = filter
|
||||
self._direction = direction
|
||||
self._filter_system_frames = filter_system_frames
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
# Ignore system frames, end frames and frames that are not following the
|
||||
# direction of this gate
|
||||
def _should_passthrough_frame(self, frame, direction):
|
||||
"""Check if a frame should pass through without filtering."""
|
||||
# Ignore system frames, end frames and frames that are not following the
|
||||
# direction of this gate
|
||||
return isinstance(frame, (SystemFrame, EndFrame)) or direction != self._direction
|
||||
# Always passthrough frames in the wrong direction
|
||||
if direction != self._direction:
|
||||
return True
|
||||
|
||||
# Always passthrough lifecycle frames
|
||||
if isinstance(frame, (StartFrame, EndFrame, CancelFrame)):
|
||||
return True
|
||||
|
||||
# If not filtering system frames, passthrough all other system frames
|
||||
if not self._filter_system_frames and isinstance(frame, SystemFrame):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame through the filter.
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
"""Unit tests for ServiceSwitcher and related components."""
|
||||
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -52,6 +54,13 @@ class MockFrameProcessor(FrameProcessor):
|
||||
self.frame_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySystemFrame(SystemFrame):
|
||||
"""A dummy system frame for testing purposes."""
|
||||
|
||||
text: str = ""
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
|
||||
@@ -140,14 +149,22 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
# Send some test frames
|
||||
frames_to_send = [
|
||||
TextFrame(text="Hello 1"),
|
||||
DummySystemFrame(text="System Message 1"),
|
||||
TextFrame(text="Hello 2"),
|
||||
DummySystemFrame(text="System Message 2"),
|
||||
TextFrame(text="Hello 3"),
|
||||
]
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[TextFrame, TextFrame, TextFrame],
|
||||
expected_down_frames=[
|
||||
DummySystemFrame,
|
||||
DummySystemFrame,
|
||||
TextFrame,
|
||||
TextFrame,
|
||||
TextFrame,
|
||||
],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
@@ -156,7 +173,13 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
|
||||
self.assertEqual(len(text_frames), 3)
|
||||
|
||||
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
|
||||
# Only service1 should have processed the system frames
|
||||
system_frames = [
|
||||
f for f in self.service1.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
self.assertEqual(len(system_frames), 2)
|
||||
|
||||
# Check that other services don't receive text frames (they still get StartFrame/EndFrame)
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
@@ -166,10 +189,24 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(service2_text_frames), 0)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
|
||||
# Check that other services don't receive dummy system frames (they still get StartFrame/EndFrame)
|
||||
service2_system_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
service3_system_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, DummySystemFrame)
|
||||
]
|
||||
self.assertEqual(len(service2_system_frames), 0)
|
||||
self.assertEqual(len(service3_system_frames), 0)
|
||||
|
||||
# Verify the actual text frames processed
|
||||
for i, frame in enumerate(text_frames):
|
||||
self.assertEqual(frame.text, f"Hello {i + 1}")
|
||||
|
||||
# Verify the actual system frames processed
|
||||
for i, frame in enumerate(system_frames):
|
||||
self.assertEqual(frame.text, f"System Message {i + 1}")
|
||||
|
||||
async def test_service_switching(self):
|
||||
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
Reference in New Issue
Block a user