Compare commits

1 Commit

Author: James Hush
SHA1: 6bb3cb2b83
Message: demo: DelayProcessor
Date: 2025-09-11 16:05:08 +08:00
101 changed files with 1940 additions and 5157 deletions

View File

@@ -9,164 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `on_before_disconnect` synchronous event to `DailyTransport` and
`LiveKitTransport`.
- It is now possible to register synchronous event handlers. By default, all
event handlers are executed in a separate task. However, in some cases we want
to guarantee order of execution, for example, executing something before
disconnecting a transport.
```python
self._register_event_handler("on_event_name", sync=True)
```
- Added support for global location in `GoogleVertexLLMService`. The service now
supports both regional locations (e.g., "us-east4") and the "global" location
for Vertex AI endpoints. When using "global" location, the service will use
`aiplatform.googleapis.com` as the API host instead of the regional format.
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
fired when the pipeline is done running. This can be the result of a
`StopFrame`, `CancelFrame` or `EndFrame`.
```python
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
...
```
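The difference between the default task-based handlers and the new synchronous ones can be illustrated with a small standalone sketch. `EventEmitter`, `register`, `add_handler`, and `wait_for_background` are hypothetical names invented for this illustration; this is not Pipecat's implementation, only a plain `asyncio` model of the behaviour described above.

```python
import asyncio


class EventEmitter:
    """Toy emitter (hypothetical, not Pipecat's code)."""

    def __init__(self):
        self._handlers = {}  # event name -> list of coroutine functions
        self._sync = {}      # event name -> True if handlers must be awaited
        self._tasks = set()  # background tasks for non-sync events

    def register(self, name, *, sync=False):
        self._handlers[name] = []
        self._sync[name] = sync

    def add_handler(self, name, handler):
        self._handlers[name].append(handler)

    async def emit(self, name, *args):
        for handler in self._handlers[name]:
            if self._sync[name]:
                # Synchronous event: finish the handler before continuing.
                await handler(*args)
            else:
                # Default: run the handler in its own task.
                task = asyncio.create_task(handler(*args))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)

    async def wait_for_background(self):
        if self._tasks:
            await asyncio.gather(*self._tasks)


async def demo():
    log = []
    emitter = EventEmitter()
    emitter.register("on_before_disconnect", sync=True)
    emitter.register("on_disconnected")

    async def before(_transport):
        await asyncio.sleep(0.01)
        log.append("before_disconnect handled")

    async def after(_transport):
        log.append("disconnected handled")

    emitter.add_handler("on_before_disconnect", before)
    emitter.add_handler("on_disconnected", after)

    await emitter.emit("on_before_disconnect", "transport")
    log.append("disconnecting")
    await emitter.emit("on_disconnected", "transport")
    log.append("emit returned")
    await emitter.wait_for_background()
    return log


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

In this sketch the synchronous handler always finishes before `"disconnecting"` is logged, while the task-based handler only runs after `emit` has already returned.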
### Changed
- Updated Silero VAD model to v6.
- Updated `livekit` to 1.0.13.
- `torch` and `torchaudio` are no longer required for running Smart Turn
locally. This avoids gigabytes of dependencies being installed.
- Updated `websockets` dependency to support version 15.0. Removed deprecated
usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
`AWSTranscribeSTTService` for compatibility.
- Refactored `pyproject.toml` to reduce websockets dependency repetition using
self-referencing extras. All websockets-dependent services now reference a
shared `websockets-base` extra.
### Deprecated
- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
longer needed to determine which transcription or translation frames to
emit.
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
instead.
### Fixed
- Fixed an issue where multiple handlers for an event would not run in parallel.
- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
ID from the `on_dialin_connected` event, when not explicitly provided. Now
supports cold transfers (from incoming dial-in calls) by automatically
tracking session IDs from connection events.
- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
the code never consumes these frames, they are queued in memory, causing a
memory leak.
- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
pushed.
- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
not wait if a previous interruption had already happened.
- Fixed a couple of bugs in `ServiceSwitcher`:
- Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
- `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
an effect too early, essentially "jumping the queue" in terms of pipeline
frame ordering.
- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
`False` from an idle callback. The task now terminates naturally instead of
attempting to cancel itself.
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
when a bot speaks and user input is blocked.
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
`on_client_disconnected` would be triggered when the bot ends the
conversation. That is, `on_client_disconnected` should only be triggered when
the remote client actually disconnects.
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
was blocked from moving through the Pipeline.
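The `UserIdleProcessor` fix above (letting the task end on its own rather than cancelling itself) follows a common `asyncio` pattern. A minimal sketch, assuming a hypothetical `idle_loop` helper and callback signature that are not taken from Pipecat:

```python
import asyncio


async def idle_loop(callback, timeout):
    """Call `callback` after every idle `timeout`; stop when it returns False.

    Hypothetical helper for illustration only, not the UserIdleProcessor code.
    """
    retries = 0
    while True:
        await asyncio.sleep(timeout)
        retries += 1
        should_continue = await callback(retries)
        if not should_continue:
            # Leave the loop so the task finishes naturally. No code running
            # inside this task tries to cancel the task it is running in.
            break
    return retries


async def demo():
    async def on_idle(retry_count):
        # Keep monitoring for the first two timeouts, then stop.
        return retry_count < 3

    task = asyncio.create_task(idle_loop(on_idle, timeout=0.01))
    return await task


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The loop owns its own exit condition, so nothing outside or inside the task needs to call `cancel()` for the normal "stop monitoring" case.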
## [0.0.85] - 2025-09-12
### Added
- `AzureSTTService` now pushes interim transcriptions.
- Added `voice_cloning_key` to `GoogleTTSService` to support custom cloned
voices.
- Added `speaking_rate` to `GoogleTTSService.InputParams` to control the
speaking rate.
- Added a `speed` arg to `OpenAITTSService` to control the speed of the voice
response.
- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this
method to programmatically interrupt the bot from any part of the
pipeline. This guarantees that all the processors in the pipeline are
interrupted in order (from upstream to downstream). Internally, this works by
first pushing an `InterruptionTaskFrame` upstream until it reaches the
pipeline task. The pipeline task then generates an `InterruptionFrame`, which
flows downstream through all processors. Once the `InterruptionFrame` has
reached the processor waiting for the interruption, the function returns and
execution continues after the call. Think of it as sending an upstream request
for interruption and waiting until the acknowledgment flows back downstream.
- Added new base `TaskFrame` (which is a system frame). This is the base class
for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant
to be pushed upstream to reach the pipeline task.
- Expanded support for universal `LLMContext` to the AWS Bedrock LLM service.
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
a prerequisite for using `LLMSwitcher` to switch between LLMs at runtime.
- Added new fields to the development runner's `parse_telephony_websocket`
method in support of providing dynamic data to a bot.
- Twilio: Added a new `body` parameter, which parses the websocket message
for `customParameters`. Provide data via the `Parameter` nouns in your
TwiML to use this feature.
- Telnyx & Exotel: Both providers make the `to` and `from` phone numbers
available in the websocket messages. You can now access these numbers as
`call_data["to"]` and `call_data["from"]`.
Note: Each telephony provider offers different features. Refer to the
corresponding example in `pipecat-examples` to see how to pass custom data
to your bot.
- Added `body` to the `WebsocketRunnerArguments` as an optional parameter.
Custom `body` information can be passed from the server into the bot file via
the `bot()` method using this new parameter.
- Added video streaming support to `LiveKitTransport`.
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide
access to OpenAI Realtime.
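The upstream-request, downstream-acknowledgment round trip described for `push_interruption_task_frame_and_wait()` can be modelled with plain `asyncio`. `ToyPipeline` and its methods are hypothetical names for this sketch; real Pipecat processors and frames are not used, and the actual implementation may differ.

```python
import asyncio


class ToyPipeline:
    """Hypothetical stand-in for a pipeline; not Pipecat's implementation."""

    def __init__(self, names):
        self.names = names           # processors, listed upstream to downstream
        self.trace = []              # order in which things happen
        self._waiters = {}           # processor name -> asyncio.Event
        self.downstream_task = None

    async def push_interruption_and_wait(self, origin):
        event = asyncio.Event()
        self._waiters[origin] = event
        # Upstream leg: the request passes every processor before `origin`.
        for name in reversed(self.names[: self.names.index(origin)]):
            self.trace.append(f"up:{name}")
        self.trace.append("task")
        # The task answers by sending an interruption downstream.
        self.downstream_task = asyncio.create_task(self._flow_downstream())
        await event.wait()
        self.trace.append(f"resumed:{origin}")

    async def _flow_downstream(self):
        for name in self.names:
            await asyncio.sleep(0)   # let other coroutines run between hops
            self.trace.append(f"down:{name}")
            waiter = self._waiters.pop(name, None)
            if waiter is not None:
                waiter.set()         # acknowledgment reached the waiting processor


async def demo():
    pipeline = ToyPipeline(["stt", "llm", "tts"])
    await pipeline.push_interruption_and_wait("llm")
    await pipeline.downstream_task
    return pipeline.trace


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The caller at `llm` resumes only after the downstream interruption has passed `stt` and reached `llm`, which is the ordering guarantee the changelog entry describes.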
### Changed
- `pipeline.tests.utils.run_test()` now allows passing `PipelineParams` instead
of individual parameters.
### Removed
- Remove `VisionImageRawFrame` in favor of context frames (`LLMContextFrame` or
@@ -174,10 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated
- `BotInterruptionFrame` is now deprecated, use `InterruptionTaskFrame` instead.
- `StartInterruptionFrame` is now deprecated, use `InterruptionFrame` instead.
- Deprecate `VisionImageFrameAggregator` because `VisionImageRawFrame` has been
removed. See the `12*` examples for the new recommended replacement pattern.
@@ -190,9 +33,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed a `BaseOutputTransport` issue that caused incorrect detection of when
the bot stopped talking while using an audio mixer.
- Fixed a `LiveKitTransport` issue where RTVI messages were not properly
encoded.

View File

@@ -21,8 +21,6 @@
🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.
🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
## 🧠 Why Pipecat?
- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling
@@ -155,11 +153,7 @@ You can get started with Pipecat running on your local machine, then move your a
2. Install development and testing dependencies:
```bash
uv sync --group dev --all-extras \
--no-extra gstreamer \
--no-extra krisp \
--no-extra local \
--no-extra ultravox # (ultravox not fully supported on macOS)
uv sync --group dev --all-extras --no-extra gstreamer --no-extra krisp --no-extra local
```
3. Install the git pre-commit hooks:
@@ -168,6 +162,23 @@ You can get started with Pipecat running on your local machine, then move your a
uv run pre-commit install
```
### Python 3.13+ Compatibility
Some features require PyTorch, which doesn't yet support Python 3.13+. Install using:
```bash
uv sync --group dev --all-extras \
--no-extra gstreamer \
--no-extra krisp \
--no-extra local \
--no-extra local-smart-turn \
--no-extra mlx-whisper \
--no-extra moondream \
--no-extra ultravox
```
> **Tip:** For full compatibility, use Python 3.12: `uv python pin 3.12`
> **Note**: Some extras (local, gstreamer) require system dependencies. See documentation if you encounter build errors.
### Running tests

View File

@@ -11,7 +11,7 @@ import sys
from dotenv import load_dotenv
from loguru import logger
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.frames.frames import TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
@@ -50,7 +50,7 @@ async def main():
async def on_first_participant_joined(transport, participant_id):
await asyncio.sleep(1)
await task.queue_frame(
TTSSpeakFrame(
TextFrame(
"Hello there! How are you doing today? Would you like to talk about the weather?"
)
)

View File

@@ -14,7 +14,7 @@ from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
InterruptionFrame,
BotInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -115,7 +115,7 @@ async def main():
await task.queue_frames(
[
InterruptionFrame(),
BotInterruptionFrame(),
UserStartedSpeakingFrame(),
TranscriptionFrame(
user_id=participant_id,

View File

@@ -4,17 +4,19 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.frames.frames import Frame, LLMFullResponseEndFrame, LLMRunFrame, LLMTextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
@@ -26,6 +28,62 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
class DelayProcessor(FrameProcessor):
"""Custom processor that queues LLM text frames until response is complete.
This creates a more natural conversation flow by preventing the agent from
responding immediately after the user stops speaking. It queues all LLMTextFrames
until it sees an LLMFullResponseEndFrame, then waits for the specified delay
before releasing all queued frames at once.
"""
def __init__(self, *, delay_seconds: float = 1.0, **kwargs) -> None:
"""Initialize the DelayProcessor.
Args:
delay_seconds: Number of seconds to delay before releasing queued frames (default: 1.0)
"""
super().__init__(**kwargs)
self._delay_seconds = delay_seconds
self._queued_frames = []
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process frames, queuing LLM text frames until response is complete.
Args:
frame: The frame to process
direction: Direction of the frame in the pipeline
"""
await super().process_frame(frame, direction)
if isinstance(frame, LLMTextFrame):
# Queue LLM text frames instead of pushing them immediately
logger.debug(f"Queuing LLMTextFrame: {frame.text}")
self._queued_frames.append((frame, direction))
elif isinstance(frame, LLMFullResponseEndFrame):
# When we see the end frame, wait for delay then push all queued frames
logger.debug(
f"LLM response complete, delaying {self._delay_seconds} seconds before releasing {len(self._queued_frames)} queued frames"
)
await asyncio.sleep(self._delay_seconds)
# Push all queued LLM text frames
for queued_frame, queued_direction in self._queued_frames:
logger.debug(f"Releasing queued LLMTextFrame: {queued_frame.text}")
await self.push_frame(queued_frame, queued_direction)
# Clear the queue
self._queued_frames.clear()
# Push the end frame
logger.debug("Pushing LLMFullResponseEndFrame")
await self.push_frame(frame, direction)
else:
# Push all other frames immediately
await self.push_frame(frame, direction)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -70,12 +128,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
# Create delay processor to add 1-second delay before agent responses
delay_processor = DelayProcessor(delay_seconds=1.0)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
delay_processor, # Add delay before TTS
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
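The buffering behaviour of `DelayProcessor` can be tried without a running pipeline. The sketch below is a standalone model: `Text`, `End`, `Other`, and `DelayBuffer` are hypothetical stand-ins for `LLMTextFrame`, `LLMFullResponseEndFrame`, any other frame, and the processor itself, and an output list replaces `push_frame`.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Text:
    text: str


@dataclass
class End:
    pass


@dataclass
class Other:
    name: str


class DelayBuffer:
    """Standalone model of the hold-then-release logic (not a FrameProcessor)."""

    def __init__(self, delay_seconds=1.0):
        self._delay_seconds = delay_seconds
        self._queued = []
        self.out = []  # stands in for frames pushed to the next processor

    async def process(self, item):
        if isinstance(item, Text):
            # Hold text chunks back until the response is complete.
            self._queued.append(item)
        elif isinstance(item, End):
            # Response complete: wait, release the chunks in order, then the end marker.
            await asyncio.sleep(self._delay_seconds)
            self.out.extend(self._queued)
            self._queued.clear()
            self.out.append(item)
        else:
            # Everything else passes straight through.
            self.out.append(item)


async def demo():
    buf = DelayBuffer(delay_seconds=0.01)
    await buf.process(Other("start"))
    await buf.process(Text("Hel"))
    await buf.process(Text("lo"))
    before_end = list(buf.out)  # snapshot taken while text is still held back
    await buf.process(End())
    return before_end, buf.out


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Until the end marker arrives, only the pass-through item has been emitted; the text chunks appear together, in their original order, after the delay.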

View File

@@ -36,6 +36,7 @@ load_dotenv(override=True)
audiobuffer = AudioBufferProcessor(
num_channels=2, # 1 for mono, 2 for stereo (user left, bot right)
enable_turn_audio=False, # Enable per-turn audio recording
user_continuous_stream=True, # User has continuous audio stream
)

View File

@@ -12,8 +12,8 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.frames.frames import (
InterruptionFrame,
LLMRunFrame,
StartInterruptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
@stt.event_handler("on_speech_started")
async def on_speech_started(stt, *args, **kwargs):
await task.queue_frames([InterruptionFrame(), UserStartedSpeakingFrame()])
await task.queue_frames([StartInterruptionFrame(), UserStartedSpeakingFrame()])
@stt.event_handler("on_utterance_end")
async def on_utterance_end(stt, *args, **kwargs):

View File

@@ -16,10 +16,10 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMRunFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -181,7 +181,9 @@ class TranscriptionContextFixup(FrameProcessor):
if isinstance(frame, MagicDemoTranscriptionFrame):
self._transcript = frame.text
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, InterruptionFrame):
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(
frame, StartInterruptionFrame
):
self.swap_user_audio()
self.add_transcript_back_to_inference_output()
self._transcript = ""

View File

@@ -13,7 +13,6 @@ from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
TextFrame,
TTSSpeakFrame,
UserImageRawFrame,
@@ -22,7 +21,10 @@ from pipecat.frames.frames import (
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.aggregators.user_response import UserResponseAggregator
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.runner.types import RunnerArguments
@@ -71,14 +73,14 @@ class UserImageProcessor(FrameProcessor):
if isinstance(frame, UserImageRawFrame):
if frame.request and frame.request.context:
# Note: AWS Bedrock does not yet support the universal LLMContext
context = LLMContext()
context = OpenAILLMContext()
context.add_image_frame_message(
image=frame.image,
text=frame.request.context,
size=frame.size,
format=frame.format,
)
frame = LLMContextFrame(context)
frame = OpenAILLMContextFrame(context)
await self.push_frame(frame)
else:
await self.push_frame(frame, direction)
@@ -119,9 +121,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
aws = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
# Note: usually, prefer providing latency="optimized" param.
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
# which we need for image input.
params=AWSBedrockLLMService.InputParams(temperature=0.8),
)

View File

@@ -1,214 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import (
create_transport,
get_transport_client_id,
maybe_capture_participant_camera,
)
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
# Global variable to store the client ID
client_id = ""
async def get_weather(params: FunctionCallParams):
location = params.arguments["location"]
await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
async def get_image(params: FunctionCallParams):
question = params.arguments["question"]
logger.debug(f"Requesting image with user_id={client_id}, question={question}")
# Request the image frame
await params.llm.request_image_frame(
user_id=client_id,
function_name=params.function_name,
tool_call_id=params.tool_call_id,
text_content=question,
)
# Wait a short time for the frame to be processed
await asyncio.sleep(0.5)
# Return a result to complete the function call
await params.result_callback(
f"I've captured an image from your camera and I'm analyzing what you asked about: {question}"
)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
video_in_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
video_in_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
llm = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
# Note: usually, prefer providing latency="optimized" param.
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
# which we need for image input.
params=AWSBedrockLLMService.InputParams(temperature=0.8),
)
llm.register_function("get_weather", get_weather)
llm.register_function("get_image", get_image)
weather_function = FunctionSchema(
name="get_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
required=["location"],
)
get_image_function = FunctionSchema(
name="get_image",
description="Get an image from the video stream.",
properties={
"question": {
"type": "string",
"description": "The question that the user is asking about the image.",
}
},
required=["question"],
)
tools = ToolsSchema(standard_tools=[weather_function, get_image_function])
system_prompt = """\
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
Your response will be turned into speech so use only simple words and punctuation.
You have access to two tools: get_weather and get_image.
You can respond to questions about the weather using the get_weather tool.
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
indicate you should use the get_image tool are:
- What do you see?
- What's in the video?
- Can you describe the video?
- Tell me about what you see.
- Tell me something interesting about what you see.
- What's happening in the video?
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Start the conversation by introducing yourself."},
]
context = LLMContext(messages, tools)
context_aggregator = LLMContextAggregatorPair(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # STT
context_aggregator.user(), # User speech to text
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses and tool context
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_camera(transport, client)
global client_id
client_id = get_transport_client_id(transport, client)
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -22,7 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai_realtime_beta import (
InputAudioNoiseReduction,
@@ -31,6 +31,7 @@ from pipecat.services.openai_realtime_beta import (
SemanticTurnDetection,
SessionProperties,
)
from pipecat.services.openai_realtime_beta.events import AudioConfiguration, AudioInput
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -113,14 +114,18 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
session_properties = SessionProperties(
input_audio_transcription=InputAudioTranscription(),
modalities=["text"],
# Set openai TurnDetection parameters. Not setting this at all will turn it
# on by default
turn_detection=SemanticTurnDetection(),
# Or set to False to disable openai turn detection and use transport VAD
# turn_detection=False,
input_audio_noise_reduction=InputAudioNoiseReduction(type="near_field"),
audio=AudioConfiguration(
input=AudioInput(
transcription=InputAudioTranscription(),
# Set openai TurnDetection parameters. Not setting this at all will turn it
# on by default
turn_detection=SemanticTurnDetection(),
# Or set to False to disable openai turn detection and use transport VAD
# turn_detection=False,
noise_reduction=InputAudioNoiseReduction(type="near_field"),
)
),
output_modalities=["text"],
# tools=tools,
instructions="""You are a helpful and friendly AI.

View File

@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InterruptionFrame,
LLMRunFrame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
TextFrame,
TranscriptionFrame,
@@ -144,7 +144,7 @@ class OutputGate(FrameProcessor):
await self._start()
if isinstance(frame, (EndFrame, CancelFrame)):
await self._stop()
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
self._frames_buffer = []
self.close_gate()
await self.push_frame(frame, direction)
@@ -232,7 +232,7 @@ class TurnDetectionLLM(Pipeline):
async def pass_only_llm_trigger_frames(frame):
return (
isinstance(frame, OpenAILLMContextFrame)
or isinstance(frame, InterruptionFrame)
or isinstance(frame, StartInterruptionFrame)
or isinstance(frame, FunctionCallInProgressFrame)
or isinstance(frame, FunctionCallResultFrame)
)

View File

@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InterruptionFrame,
LLMRunFrame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
TextFrame,
TranscriptionFrame,
@@ -347,7 +347,7 @@ class OutputGate(FrameProcessor):
await self._start()
if isinstance(frame, (EndFrame, CancelFrame)):
await self._stop()
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
self._frames_buffer = []
self.close_gate()
await self.push_frame(frame, direction)
@@ -426,7 +426,7 @@ class TurnDetectionLLM(Pipeline):
async def pass_only_llm_trigger_frames(frame):
return (
isinstance(frame, OpenAILLMContextFrame)
or isinstance(frame, InterruptionFrame)
or isinstance(frame, StartInterruptionFrame)
or isinstance(frame, FunctionCallInProgressFrame)
or isinstance(frame, FunctionCallResultFrame)
)

View File

@@ -20,10 +20,10 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
InterruptionFrame,
LLMFullResponseStartFrame,
LLMRunFrame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
TextFrame,
TranscriptionFrame,
@@ -570,7 +570,7 @@ class OutputGate(FrameProcessor):
await self._start()
if isinstance(frame, (EndFrame, CancelFrame)):
await self._stop()
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
self._frames_buffer = []
self.close_gate()
await self.push_frame(frame, direction)

View File

@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
InterruptionFrame,
LLMRunFrame,
StartInterruptionFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
@@ -48,7 +48,7 @@ class CustomObserver(BaseObserver):
"""Observer to log interruptions and bot speaking events to the console.
Logs all frame instances of:
- InterruptionFrame
- StartInterruptionFrame
- BotStartedSpeakingFrame
- BotStoppedSpeakingFrame
@@ -69,7 +69,7 @@ class CustomObserver(BaseObserver):
# Create direction arrow
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
if isinstance(frame, InterruptionFrame) and isinstance(src, BaseOutputTransport):
if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport):
logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s")
elif isinstance(frame, BotStartedSpeakingFrame):
logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")

View File

@@ -11,7 +11,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
@@ -30,6 +30,23 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
# to the path where the smart-turn repo is cloned.
#
# Example setup:
#
# # Git LFS (Large File Storage)
# brew install git-lfs
# # Hugging Face uses LFS to store large model files, including .mlpackage
# git lfs install
# # Clone the repo with the smart_turn_classifier.mlpackage
# git clone https://huggingface.co/pipecat-ai/smart-turn-v2
#
# Then set the env variable:
# export LOCAL_SMART_TURN_MODEL_PATH=./smart-turn-v2
# or add it to your .env file
smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -38,19 +55,25 @@ transport_params = {
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
turn_analyzer=LocalSmartTurnAnalyzerV2(
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
turn_analyzer=LocalSmartTurnAnalyzerV2(
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
turn_analyzer=LocalSmartTurnAnalyzerV2(
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
),
),
}

View File

@@ -47,32 +47,32 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
aic = [ "aic-sdk~=1.0.1" ]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "pipecat-ai[websockets-base]" ]
asyncai = [ "pipecat-ai[websockets-base]" ]
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
assemblyai = [ "websockets>=13.1,<15.0" ]
asyncai = [ "websockets>=13.1,<15.0" ]
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
cerebras = []
deepseek = []
daily = [ "daily-python~=0.19.9" ]
deepgram = [ "deepgram-sdk~=4.7.0" ]
elevenlabs = [ "pipecat-ai[websockets-base]" ]
elevenlabs = [ "websockets>=13.1,<15.0" ]
fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
gladia = [ "pipecat-ai[websockets-base]" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
gladia = [ "websockets>=13.1,<15.0" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
inworld = []
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
koala = [ "pvkoala~=2.0.3" ]
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
lmnt = [ "pipecat-ai[websockets-base]" ]
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
lmnt = [ "websockets>=13.1,<15.0" ]
local = [ "pyaudio~=0.2.14" ]
mcp = [ "mcp[cli]~=1.9.4" ]
mem0 = [ "mem0ai~=0.1.94" ]
@@ -80,35 +80,33 @@ mistral = []
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
nim = []
neuphonic = [ "pipecat-ai[websockets-base]" ]
neuphonic = [ "websockets>=13.1,<15.0" ]
noisereduce = [ "noisereduce~=3.0.3" ]
openai = [ "pipecat-ai[websockets-base]" ]
openai = [ "websockets>=13.1,<15.0" ]
openpipe = [ "openpipe~=4.50.0" ]
openrouter = []
perplexity = []
playht = [ "pipecat-ai[websockets-base]" ]
playht = [ "websockets>=13.1,<15.0" ]
qwen = []
rime = [ "pipecat-ai[websockets-base]" ]
rime = [ "websockets>=13.1,<15.0" ]
riva = [ "nvidia-riva-client~=2.21.1" ]
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
sambanova = []
sarvam = [ "pipecat-ai[websockets-base]" ]
sarvam = [ "websockets>=13.1,<15.0" ]
sentry = [ "sentry-sdk~=2.23.1" ]
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
remote-smart-turn = []
silero = [ "onnxruntime>=1.20.1, <2" ]
silero = [ "onnxruntime~=1.20.1" ]
simli = [ "simli-ai~=0.1.10"]
soniox = [ "pipecat-ai[websockets-base]" ]
soniox = [ "websockets>=13.1,<15.0" ]
soundfile = [ "soundfile~=0.13.0" ]
speechmatics = [ "speechmatics-rt>=0.4.0" ]
tavus=[]
together = []
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
websockets-base = [ "websockets>=13.1,<16.0" ]
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
whisper = [ "faster-whisper~=1.1.1" ]
[dependency-groups]
@@ -156,7 +154,6 @@ where = ["src"]
"src/pipecat/audio/dtmf/dtmf-star.wav",
]
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
[tool.pytest.ini_options]
addopts = "--verbose"


@@ -135,25 +135,6 @@ TESTS_14 = [
("14r-function-calling-aws.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("14v-function-calling-openai.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("14w-function-calling-mistral.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("14x-function-calling-universal-context.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
(
"14y-function-calling-google-universal-context.py",
PROMPT_WEATHER,
EVAL_WEATHER,
BOT_SPEAKS_FIRST,
),
(
"14z-function-calling-anthropic-universal-context.py",
PROMPT_WEATHER,
EVAL_WEATHER,
BOT_SPEAKS_FIRST,
),
(
"14aa-function-calling-aws-universal-context.py",
PROMPT_WEATHER,
EVAL_WEATHER,
BOT_SPEAKS_FIRST,
),
# Currently not working.
# ("14c-function-calling-together.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
# ("14l-function-calling-deepseek.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
@@ -167,7 +148,6 @@ TESTS_15 = [
TESTS_19 = [
("19-openai-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("19a-azure-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("19b-openai-realtime-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
("19b-openai-realtime-beta-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
]


@@ -1,12 +0,0 @@
#!/bin/bash
PID=$1
while true; do
# Clear the screen
clear
# Print the header + RSS in GB
ps -p "$PID" -o pid,comm,rss | \
awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
sleep 1
done


@@ -16,12 +16,7 @@ from typing import Any, Dict, Generic, List, TypeVar
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMSpecificMessage,
NotGiven,
)
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
# Should be a TypedDict
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
@@ -43,16 +38,6 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
Subclasses must implement provider-specific conversion logic.
"""
@property
@abstractmethod
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
Returns:
The identifier string.
"""
pass
@abstractmethod
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
"""Get provider-specific LLM invocation parameters from a universal LLM context.
@@ -91,28 +76,6 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
"""
pass
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
"""Get messages from the LLM context, including standard and LLM-specific messages.
Args:
context: The LLM context containing messages.
Returns:
List of messages including standard and LLM-specific messages.
"""
return context.get_messages(self.id_for_llm_specific_messages)
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
"""Convert tools from standard format to provider format.

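The `create_llm_specific_message` and `get_messages` helpers in this hunk tag a message with a provider identifier and then ask the context for the messages relevant to that provider. The sketch below shows one way such filtering could work. It assumes that standard messages are visible to every provider and that tagged messages are visible only to the matching one. The real `LLMContext.get_messages` behavior is not shown in this diff, and `SpecificMessage` and `messages_for` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class SpecificMessage:
    """Hypothetical stand-in for LLMSpecificMessage."""

    llm: str
    message: Any


Message = Union[Dict[str, Any], SpecificMessage]


def messages_for(messages: List[Message], llm_id: str) -> List[Message]:
    """Keep standard messages plus those tagged for `llm_id` (assumed behavior)."""
    return [
        m for m in messages
        if not isinstance(m, SpecificMessage) or m.llm == llm_id
    ]


history: List[Message] = [
    {"role": "user", "content": "hi"},
    SpecificMessage(llm="anthropic", message={"role": "assistant", "content": "..."}),
    SpecificMessage(llm="google", message={"role": "model", "parts": []}),
]

anthropic_view = messages_for(history, "anthropic")
```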

@@ -9,7 +9,7 @@
import copy
import json
from dataclasses import dataclass
from typing import Any, Dict, List, TypedDict
from typing import Any, Dict, List, Optional, TypedDict
from anthropic import NOT_GIVEN, NotGiven
from anthropic.types.message_param import MessageParam
@@ -28,7 +28,10 @@ from pipecat.processors.aggregators.llm_context import (
class AnthropicLLMInvocationParams(TypedDict):
"""Context-based parameters for invoking Anthropic's LLM API."""
"""Context-based parameters for invoking Anthropic's LLM API.
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
"""
system: str | NotGiven
messages: List[MessageParam]
@@ -42,16 +45,13 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
to the specific format required by Anthropic's Claude models for function calling.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
return "anthropic"
def get_llm_invocation_params(
self, context: LLMContext, enable_prompt_caching: bool
) -> AnthropicLLMInvocationParams:
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
Args:
context: The LLM context containing messages, tools, etc.
enable_prompt_caching: Whether prompt caching should be enabled.
@@ -59,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking Anthropic's LLM API.
"""
messages = self._from_universal_context_messages(self.get_messages(context))
messages = self._from_universal_context_messages(self._get_messages(context))
return {
"system": messages.system,
"messages": (
@@ -76,6 +76,8 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Removes or truncates sensitive data like image content for safe logging.
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
Args:
context: The LLM context containing messages.
@@ -83,7 +85,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
List of messages in a format ready for logging about Anthropic.
"""
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self.get_messages(context)).messages
messages = self._from_universal_context_messages(self._get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -97,6 +99,9 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("anthropic")
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""


@@ -31,11 +31,6 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
specific function-calling format, enabling tool use with Nova Sonic models.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.


@@ -6,33 +6,21 @@
"""AWS Bedrock LLM adapter for Pipecat."""
import base64
import copy
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, TypedDict
from loguru import logger
from typing import Any, Dict, List, TypedDict
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMContextToolChoice,
LLMSpecificMessage,
LLMStandardMessage,
)
from pipecat.processors.aggregators.llm_context import LLMContext
class AWSBedrockLLMInvocationParams(TypedDict):
"""Context-based parameters for invoking AWS Bedrock's LLM API."""
"""Context-based parameters for invoking AWS Bedrock's LLM API.
system: Optional[List[dict[str, Any]]] # [{"text": "system message"}]
messages: List[dict[str, Any]]
tools: List[dict[str, Any]]
tool_choice: LLMContextToolChoice
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
"""
pass
class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
@@ -42,244 +30,33 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
into AWS Bedrock's expected tool format for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
return "aws"
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
Args:
context: The LLM context containing messages, tools, etc.
Returns:
Dictionary of parameters for invoking AWS Bedrock's LLM API.
"""
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,
"messages": messages.messages,
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools) or [],
# To avoid refactoring in AWSBedrockLLMService, we just pass through tool_choice.
# Eventually (when we don't have to maintain the non-LLMContext code path) we should do
# the conversion to Bedrock's expected format here rather than in AWSBedrockLLMService.
"tool_choice": context.tool_choice,
}
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
"""Get messages from a universal LLM context in a format ready for logging about AWS Bedrock.
Removes or truncates sensitive data like image content for safe logging.
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
Args:
context: The LLM context containing messages.
Returns:
List of messages in a format ready for logging about AWS Bedrock.
"""
# Get messages in Bedrock's format
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
for message in messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item.get("image"):
item["image"]["source"]["bytes"] = "..."
messages_for_logging.append(msg)
return messages_for_logging
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""
messages: List[dict[str, Any]]
system: Optional[str]
def _from_universal_context_messages(
self, universal_context_messages: List[LLMContextMessage]
) -> ConvertedMessages:
system = None
messages = []
# first, map messages using self._from_universal_context_message(m)
try:
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
except Exception as e:
logger.error(f"Error mapping messages: {e}")
# See if we should pull the system message out of our messages list
if messages and messages[0]["role"] == "system":
system = messages[0]["content"]
messages.pop(0)
# Convert any subsequent "system"-role messages to "user"-role
# messages, as AWS Bedrock doesn't support system input messages.
for message in messages:
if message["role"] == "system":
message["role"] = "user"
# Merge consecutive messages with the same role.
i = 0
while i < len(messages) - 1:
current_message = messages[i]
next_message = messages[i + 1]
if current_message["role"] == next_message["role"]:
# Convert content to list of dictionaries if it's a string
if isinstance(current_message["content"], str):
current_message["content"] = [
{"type": "text", "text": current_message["content"]}
]
if isinstance(next_message["content"], str):
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
# Concatenate the content
current_message["content"].extend(next_message["content"])
# Remove the next message from the list
messages.pop(i + 1)
else:
i += 1
# Avoid empty content in messages
for message in messages:
if isinstance(message["content"], str) and message["content"] == "":
message["content"] = "(empty)"
elif isinstance(message["content"], list) and len(message["content"]) == 0:
message["content"] = [{"type": "text", "text": "(empty)"}]
return self.ConvertedMessages(messages=messages, system=system)
def _from_universal_context_message(self, message: LLMContextMessage) -> dict[str, Any]:
if isinstance(message, LLMSpecificMessage):
return copy.deepcopy(message.message)
return self._from_standard_message(message)
def _from_standard_message(self, message: LLMStandardMessage) -> dict[str, Any]:
"""Convert standard format message to AWS Bedrock format.
Handles conversion of text content, tool calls, and tool results.
Empty text content is converted to "(empty)".
Args:
message: Message in standard format.
Returns:
Message in AWS Bedrock format.
Examples:
Standard format input::
{
"role": "assistant",
"tool_calls": [
{
"id": "123",
"function": {"name": "search", "arguments": '{"q": "test"}'}
}
]
}
AWS Bedrock format output::
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "123",
"name": "search",
"input": {"q": "test"}
}
}
]
}
"""
message = copy.deepcopy(message)
if message["role"] == "tool":
# Try to parse the content as JSON if it looks like JSON
try:
if message["content"].strip().startswith("{") and message[
"content"
].strip().endswith("}"):
content_json = json.loads(message["content"])
tool_result_content = [{"json": content_json}]
else:
tool_result_content = [{"text": message["content"]}]
except Exception:
tool_result_content = [{"text": message["content"]}]
return {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": message["tool_call_id"],
"content": tool_result_content,
},
},
],
}
if message.get("tool_calls"):
tc = message["tool_calls"]
ret = {"role": "assistant", "content": []}
for tool_call in tc:
function = tool_call["function"]
arguments = json.loads(function["arguments"])
new_tool_use = {
"toolUse": {
"toolUseId": tool_call["id"],
"name": function["name"],
"input": arguments,
}
}
ret["content"].append(new_tool_use)
return ret
# Handle text content
content = message.get("content")
if isinstance(content, str):
if content == "":
return {"role": message["role"], "content": [{"text": "(empty)"}]}
else:
return {"role": message["role"], "content": [{"text": content}]}
elif isinstance(content, list):
new_content = []
for item in content:
# fix empty text
if item.get("type", "") == "text":
text_content = item["text"] if item["text"] != "" else "(empty)"
new_content.append({"text": text_content})
# handle image_url -> image conversion
if item["type"] == "image_url":
new_item = {
"image": {
"format": "jpeg",
"source": {
"bytes": base64.b64decode(item["image_url"]["url"].split(",")[1])
},
}
}
new_content.append(new_item)
# In the case where there's a single image in the list (like what
# would result from a UserImageRawFrame), ensure that the image
# comes before text
image_indices = [i for i, item in enumerate(new_content) if "image" in item]
text_indices = [i for i, item in enumerate(new_content) if "text" in item]
if len(image_indices) == 1 and text_indices:
img_idx = image_indices[0]
first_txt_idx = text_indices[0]
if img_idx > first_txt_idx:
# Move image before the first text
image_item = new_content.pop(img_idx)
new_content.insert(first_txt_idx, image_item)
return {"role": message["role"], "content": new_content}
return message
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
@staticmethod
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
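
The Bedrock `_from_universal_context_messages` method in this hunk merges consecutive messages that share a role. The following standalone sketch isolates that step. It differs from the diffed code in two labeled ways: it always normalizes string content to a list of text blocks (the original only does so when a merge happens), and it returns a new list instead of mutating its input. `merge_consecutive_roles` is a hypothetical name.

```python
from typing import Any, Dict, List


def merge_consecutive_roles(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge adjacent messages that share a role, normalizing content to lists."""
    merged: List[Dict[str, Any]] = []
    for message in messages:
        content = message["content"]
        # Normalize string content to a list of text blocks.
        if isinstance(content, str):
            content = [{"type": "text", "text": content}]
        else:
            content = list(content)
        if merged and merged[-1]["role"] == message["role"]:
            # Same role as the previous message: concatenate the content.
            merged[-1]["content"].extend(content)
        else:
            merged.append({"role": message["role"], "content": content})
    return merged


result = merge_consecutive_roles(
    [
        {"role": "user", "content": "a"},
        {"role": "user", "content": "b"},
        {"role": "assistant", "content": "c"},
    ]
)
```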


@@ -54,11 +54,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Google."""
return "google"
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
@@ -68,7 +63,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
Returns:
Dictionary of parameters for Gemini's API.
"""
messages = self._from_universal_context_messages(self.get_messages(context))
messages = self._from_universal_context_messages(self._get_messages(context))
return {
"system_instruction": messages.system_instruction,
"messages": messages.messages,
@@ -108,7 +103,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
List of messages in a format ready for logging about Gemini.
"""
# Get messages in Gemini's format
messages = self._from_universal_context_messages(self.get_messages(context)).messages
messages = self._from_universal_context_messages(self._get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -124,6 +119,9 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
messages_for_logging.append(obj)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("google")
@dataclass
class ConvertedMessages:
"""Container for Google-formatted messages converted from universal context."""


@@ -24,7 +24,6 @@ from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMContextToolChoice,
LLMSpecificMessage,
NotGiven,
)
@@ -48,11 +47,6 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
return "openai"
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
@@ -63,7 +57,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
Dictionary of parameters for OpenAI's ChatCompletion API.
"""
return {
"messages": self._from_universal_context_messages(self.get_messages(context)),
"messages": self._from_universal_context_messages(self._get_messages(context)),
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
"tool_choice": context.tool_choice,
@@ -97,7 +91,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
List of messages in a format ready for logging about OpenAI.
"""
msgs = []
for message in self.get_messages(context):
for message in self._get_messages(context):
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
@@ -110,18 +104,14 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
msgs.append(msg)
return msgs
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("openai")
def _from_universal_context_messages(
self, messages: List[LLMContextMessage]
) -> List[ChatCompletionMessageParam]:
result = []
for message in messages:
if isinstance(message, LLMSpecificMessage):
# Extract the actual message content from LLMSpecificMessage
result.append(message.message)
else:
# Standard message, pass through unchanged
result.append(message)
return result
# Just a pass-through: messages are already the right type
return messages
def _from_standard_tool_choice(
self, tool_choice: LLMContextToolChoice | NotGiven


@@ -30,11 +30,6 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
OpenAI's Realtime API for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.


@@ -1,124 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.
This module provides a smart turn analyzer that uses an ONNX model for
local end-of-turn detection without requiring network connectivity.
"""
from typing import Any, Dict, Optional
import numpy as np
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
try:
import onnxruntime as ort
from transformers import WhisperFeatureExtractor
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use LocalSmartTurnAnalyzerV3, you need to `pip install pipecat-ai[local-smart-turn-v3]`."
)
raise Exception(f"Missing module: {e}")
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
"""Local turn analyzer using the smart-turn-v3 ONNX model.
Provides end-of-turn detection using locally-stored ONNX model,
enabling offline operation without network dependencies.
"""
def __init__(self, *, smart_turn_model_path: Optional[str] = None, **kwargs):
"""Initialize the local ONNX smart-turn-v3 analyzer.
Args:
smart_turn_model_path: Path to the ONNX model file. If this is not
set, the bundled smart-turn-v3.0 model will be used.
**kwargs: Additional arguments passed to BaseSmartTurn.
"""
super().__init__(**kwargs)
logger.debug("Loading Local Smart Turn v3 model...")
if not smart_turn_model_path:
# Load bundled model
model_name = "smart-turn-v3.0.onnx"
package_path = "pipecat.audio.turn.smart_turn.data"
try:
import importlib_resources as impresources
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
except BaseException:
from importlib import resources as impresources
try:
with impresources.path(package_path, model_name) as f:
smart_turn_model_path = f
except BaseException:
smart_turn_model_path = str(
impresources.files(package_path).joinpath(model_name)
)
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
so.inter_op_num_threads = 1
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
logger.debug("Loaded Local Smart Turn v3")
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
"""Predict end-of-turn using local ONNX model."""
def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000):
"""Truncate audio to last n seconds or pad with zeros to meet n seconds."""
max_samples = n_seconds * sample_rate
if len(audio_array) > max_samples:
return audio_array[-max_samples:]
elif len(audio_array) < max_samples:
# Pad with zeros at the beginning
padding = max_samples - len(audio_array)
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
return audio_array
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
# Process audio using Whisper's feature extractor
inputs = self._feature_extractor(
audio_array,
sampling_rate=16000,
return_tensors="np",
padding="max_length",
max_length=8 * 16000,
truncation=True,
do_normalize=True,
)
# Extract features and ensure correct shape for ONNX
input_features = inputs.input_features.squeeze(0).astype(np.float32)
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
# Run ONNX inference
outputs = self._session.run(None, {"input_features": input_features})
# Extract probability (ONNX model returns sigmoid probabilities)
probability = outputs[0][0].item()
# Make prediction (1 for Complete, 0 for Incomplete)
prediction = 1 if probability > 0.5 else 0
return {
"prediction": prediction,
"probability": probability,
}
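
The `truncate_audio_to_last_n_seconds` helper in this file keeps the most recent audio and left-pads shorter clips with zeros so the model always sees a fixed-length window. The sketch below reproduces that windowing logic on plain Python lists so it runs without NumPy. `last_n_seconds` is a hypothetical name, and the list-based representation is an assumption made for portability.

```python
from typing import List


def last_n_seconds(
    samples: List[float], n_seconds: int = 8, sample_rate: int = 16000
) -> List[float]:
    """Return exactly n_seconds of audio: keep the tail, or left-pad with zeros."""
    max_samples = n_seconds * sample_rate
    if len(samples) > max_samples:
        # Too long: keep only the most recent samples.
        return samples[-max_samples:]
    # Too short (or exact): pad with zeros at the beginning.
    padding = max_samples - len(samples)
    return [0.0] * padding + samples


# Tiny window (1 second at 4 Hz) to keep the example readable.
padded = last_n_seconds([1.0, 2.0], n_seconds=1, sample_rate=4)
truncated = last_n_seconds([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], n_seconds=1, sample_rate=4)
```

Padding at the start rather than the end keeps the end of the utterance aligned with the end of the window, which is the part an end-of-turn model presumably cares about most.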


@@ -21,6 +21,7 @@ from typing import List, Optional
from loguru import logger
from pipecat.frames.frames import (
BotInterruptionFrame,
EndFrame,
Frame,
LLMFullResponseEndFrame,
@@ -359,7 +360,7 @@ class ClassificationProcessor(FrameProcessor):
await self._voicemail_notifier.notify() # Clear buffered TTS frames
# Interrupt the current pipeline to stop any ongoing processing
await self.push_interruption_task_frame_and_wait()
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
# Set the voicemail event to trigger the voicemail handler
self._voicemail_event.clear()


@@ -788,6 +788,43 @@ class FatalErrorFrame(ErrorFrame):
fatal: bool = field(default=True, init=False)
@dataclass
class EndTaskFrame(SystemFrame):
"""Frame to request graceful pipeline task closure.
This is used to notify the pipeline task that the pipeline should be
closed nicely (flushing all the queued frames) by pushing an EndFrame
downstream. This frame should be pushed upstream.
"""
pass
@dataclass
class CancelTaskFrame(SystemFrame):
"""Frame to request immediate pipeline task cancellation.
This is used to notify the pipeline task that the pipeline should be
stopped immediately by pushing a CancelFrame downstream. This frame
should be pushed upstream.
"""
pass
@dataclass
class StopTaskFrame(SystemFrame):
"""Frame to request pipeline task stop while keeping processors running.
This is used to notify the pipeline task that it should be stopped as
soon as possible (flushing all the queued frames) but that the pipeline
processors should be kept in a running state. This frame should be pushed
upstream.
"""
pass
@dataclass
class FrameProcessorPauseUrgentFrame(SystemFrame):
"""Frame to pause frame processing immediately.
@@ -820,7 +857,7 @@ class FrameProcessorResumeUrgentFrame(SystemFrame):
@dataclass
class InterruptionFrame(SystemFrame):
class StartInterruptionFrame(SystemFrame):
"""Frame indicating user started speaking (interruption detected).
Emitted by the BaseInputTransport to indicate that a user has started
@@ -832,34 +869,6 @@ class InterruptionFrame(SystemFrame):
pass
@dataclass
class StartInterruptionFrame(InterruptionFrame):
"""Frame indicating user started speaking (interruption detected).
.. deprecated:: 0.0.85
This frame is deprecated and will be removed in a future version.
Instead, use `InterruptionFrame`.
Emitted by the BaseInputTransport to indicate that a user has started
speaking (i.e. is interrupting). This is similar to
UserStartedSpeakingFrame except that it should be pushed concurrently
with other frames (so the order is not guaranteed).
"""
def __post_init__(self):
super().__post_init__()
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"StartInterruptionFrame is deprecated and will be removed in a future version. "
"Instead, use InterruptionFrame.",
DeprecationWarning,
stacklevel=2,
)
@dataclass
class UserStartedSpeakingFrame(SystemFrame):
"""Frame indicating user has started speaking.
@@ -935,6 +944,20 @@ class VADUserStoppedSpeakingFrame(SystemFrame):
pass
@dataclass
class BotInterruptionFrame(SystemFrame):
"""Frame indicating the bot should be interrupted.
Emitted when the bot should be interrupted. This will mainly cause the
same actions as if the user interrupted except that the
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
This frame should be pushed upstream. It results in the BaseInputTransport
starting an interruption by pushing a StartInterruptionFrame downstream.
"""
pass
@dataclass
class BotStartedSpeakingFrame(SystemFrame):
"""Frame indicating the bot started speaking.
@@ -1266,103 +1289,6 @@ class SpeechControlParamsFrame(SystemFrame):
turn_params: Optional[SmartTurnParams] = None
#
# Task frames
#
@dataclass
class TaskFrame(SystemFrame):
"""Base frame for task frames.
This is a base class for frames that are meant to be sent and handled
upstream by the pipeline task. This might result in a corresponding frame
sent downstream (e.g. `InterruptionTaskFrame` / `InterruptionFrame` or
`EndTaskFrame` / `EndFrame`).
"""
pass
@dataclass
class EndTaskFrame(TaskFrame):
"""Frame to request graceful pipeline task closure.
This is used to notify the pipeline task that the pipeline should be
closed nicely (flushing all the queued frames) by pushing an EndFrame
downstream. This frame should be pushed upstream.
"""
pass
@dataclass
class CancelTaskFrame(TaskFrame):
"""Frame to request immediate pipeline task cancellation.
This is used to notify the pipeline task that the pipeline should be
stopped immediately by pushing a CancelFrame downstream. This frame
should be pushed upstream.
"""
pass
@dataclass
class StopTaskFrame(TaskFrame):
"""Frame to request pipeline task stop while keeping processors running.
This is used to notify the pipeline task that it should be stopped as
soon as possible (flushing all the queued frames) but that the pipeline
processors should be kept in a running state. This frame should be pushed
upstream.
"""
pass
@dataclass
class InterruptionTaskFrame(TaskFrame):
"""Frame indicating the bot should be interrupted.
Emitted when the bot should be interrupted. This will mainly cause the
same actions as if the user interrupted except that the
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
This frame should be pushed upstream.
"""
pass
@dataclass
class BotInterruptionFrame(InterruptionTaskFrame):
"""Frame indicating the bot should be interrupted.
.. deprecated:: 0.0.85
This frame is deprecated and will be removed in a future version.
Instead, use `InterruptionTaskFrame`.
Emitted when the bot should be interrupted. This will mainly cause the
same actions as if the user interrupted except that the
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
This frame should be pushed upstream.
"""
def __post_init__(self):
super().__post_init__()
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"BotInterruptionFrame is deprecated and will be removed in a future version. "
"Instead, use InterruptionTaskFrame.",
DeprecationWarning,
stacklevel=2,
)
#
# Control frames
#
@@ -1604,7 +1530,7 @@ class MixerEnableFrame(MixerControlFrame):
@dataclass
class ServiceSwitcherFrame(ControlFrame):
"""A base class for frames that affect ServiceSwitcher behavior."""
"""A base class for frames that control ServiceSwitcher behavior."""
pass
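
The deprecated frame classes in this file share one pattern: the old name subclasses its replacement and warns from `__post_init__`. Below is a minimal sketch of that pattern. Both classes are stand-ins defined here so the snippet runs without pipecat installed, and the warning text is shortened.

```python
import warnings
from dataclasses import dataclass


@dataclass
class InterruptionTaskFrame:
    """Stand-in for the replacement frame (not the real pipecat class)."""

    def __post_init__(self):
        pass


@dataclass
class BotInterruptionFrame(InterruptionTaskFrame):
    """Deprecated name kept so existing callers keep working."""

    def __post_init__(self):
        super().__post_init__()
        # Temporarily force the warning to be emitted even though
        # DeprecationWarning is filtered out by default in many contexts.
        # The previous filter state is restored when the block exits.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "BotInterruptionFrame is deprecated; use InterruptionTaskFrame.",
                DeprecationWarning,
                stacklevel=2,
            )
```

Because the deprecated class inherits from the new one, `isinstance(frame, InterruptionTaskFrame)` checks still match instances created under the old name.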

View File

@@ -54,7 +54,7 @@ class DebugLogObserver(BaseObserver):
Log frames with specific source/destination filters::
from pipecat.frames.frames import InterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.services.stt_service import STTService
@@ -62,8 +62,8 @@ class DebugLogObserver(BaseObserver):
observers=[
DebugLogObserver(
frame_types={
# Only log InterruptionFrame when source is BaseOutputTransport
InterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
# Only log StartInterruptionFrame when source is BaseOutputTransport
StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
# Only log UserStartedSpeakingFrame when destination is STTService
UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION),
# Log LLMTextFrame regardless of source or destination type

View File

@@ -6,15 +6,9 @@
"""Service switcher for switching between different services at runtime, with different switching strategies."""
from dataclasses import dataclass
from typing import Any, Generic, List, Optional, Type, TypeVar
from pipecat.frames.frames import (
ControlFrame,
Frame,
ManuallySwitchServiceFrame,
ServiceSwitcherFrame,
)
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -28,6 +22,19 @@ class ServiceSwitcherStrategy:
self.services = services
self.active_service: Optional[FrameProcessor] = None
def is_active(self, service: FrameProcessor) -> bool:
"""Determine if the given service is the currently active one.
This method should be overridden by subclasses to implement specific logic.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
raise NotImplementedError("Subclasses must implement this method.")
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -53,6 +60,17 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
super().__init__(services)
self.active_service = services[0] if services else None
def is_active(self, service: FrameProcessor) -> bool:
"""Check if the given service is the currently active one.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
return service == self.active_service
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -61,21 +79,20 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
direction: The direction of the frame (upstream or downstream).
"""
if isinstance(frame, ManuallySwitchServiceFrame):
self._set_active_if_available(frame.service)
self._set_active(frame.service)
else:
raise ValueError(f"Unsupported frame type: {type(frame)}")
def _set_active_if_available(self, service: FrameProcessor):
"""Set the active service to the given one, if it is in the list of available services.
If it's not in the list, the request is ignored, as it may have been
intended for another ServiceSwitcher in the pipeline.
def _set_active(self, service: FrameProcessor):
"""Set the active service to the given one.
Args:
service: The service to set as active.
"""
if service in self.services:
self.active_service = service
else:
raise ValueError(f"Service {service} is not in the list of available services.")
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
@@ -91,43 +108,6 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
self.services = services
self.strategy = strategy
class ServiceSwitcherFilter(FunctionFilter):
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
def __init__(
self,
wrapped_service: FrameProcessor,
active_service: FrameProcessor,
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction."""
async def filter(_: Frame) -> bool:
return self._wrapped_service == self._active_service
super().__init__(filter, direction)
self._wrapped_service = wrapped_service
self._active_service = active_service
async def process_frame(self, frame, direction):
"""Process a frame through the filter, handling special internal filter-updating frames."""
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
self._active_service = frame.active_service
# Two ServiceSwitcherFilters "sandwich" a service. Push the
# frame only to update the other side of the sandwich, but
# otherwise don't let it leave the sandwich.
if direction == self._direction:
await self.push_frame(frame, direction)
return
await super().process_frame(frame, direction)
@dataclass
class ServiceSwitcherFilterFrame(ControlFrame):
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
active_service: FrameProcessor
@staticmethod
def _make_pipeline_definitions(
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
@@ -141,18 +121,14 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
def _make_pipeline_definition(
service: FrameProcessor, strategy: ServiceSwitcherStrategy
) -> Any:
async def filter(frame) -> bool:
_ = frame
return strategy.is_active(service)
return [
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.DOWNSTREAM,
),
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
service,
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.UPSTREAM,
),
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
]
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -166,7 +142,3 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
if isinstance(frame, ServiceSwitcherFrame):
self.strategy.handle_frame(frame, direction)
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
active_service=self.strategy.active_service
)
await super().process_frame(service_switcher_filter_frame, direction)
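
One of the two variants in this diff gates each service with a plain filter function that asks the strategy whether the service is active. A rough sketch of that idea follows, assuming nothing about pipecat itself: `ManualStrategy`, `make_filter` and `check` are hypothetical names, and services are represented by arbitrary objects.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


class ManualStrategy:
    """Hypothetical stand-in for a manual switching strategy."""

    def __init__(self, services: List[object]):
        self.services = services
        self.active_service: Optional[object] = services[0] if services else None

    def is_active(self, service: object) -> bool:
        return service == self.active_service

    def set_active(self, service: object) -> None:
        # Reject services this switcher does not manage.
        if service not in self.services:
            raise ValueError(f"Service {service} is not in the list of available services.")
        self.active_service = service


def make_filter(service: object, strategy: ManualStrategy) -> Callable[[object], Awaitable[bool]]:
    """Build a filter that passes frames only while `service` is active."""

    async def frame_filter(frame: object) -> bool:
        _ = frame  # the decision depends on the strategy, not on the frame
        return strategy.is_active(service)

    return frame_filter


def check(frame_filter: Callable[[object], Awaitable[bool]]) -> bool:
    """Run a filter once outside of any pipeline (demo helper)."""
    return asyncio.run(frame_filter(None))
```

The filter closes over the strategy, so a switch takes effect on the next frame without any extra frame having to travel through the pipeline to update the filters.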

View File

@@ -32,8 +32,6 @@ from pipecat.frames.frames import (
Frame,
HeartbeatFrame,
InputAudioRawFrame,
InterruptionFrame,
InterruptionTaskFrame,
MetricsFrame,
StartFrame,
StopFrame,
@@ -115,28 +113,9 @@ class PipelineTask(BasePipelineTask):
- on_frame_reached_downstream: Called when downstream frames reach the sink
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
- on_pipeline_started: Called when pipeline starts with StartFrame
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
This includes:
- StopFrame: pipeline was stopped (processors keep connections open)
- EndFrame: pipeline ended normally
- CancelFrame: pipeline was cancelled
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
the frame if they need to handle specific cases.
- on_pipeline_stopped: Called when pipeline stops with StopFrame
- on_pipeline_ended: Called when pipeline ends with EndFrame
- on_pipeline_cancelled: Called when pipeline is cancelled
Example::
@@ -147,10 +126,6 @@ class PipelineTask(BasePipelineTask):
@task.event_handler("on_idle_timeout")
async def on_pipeline_idle_timeout(task):
...
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame):
...
"""
def __init__(
@@ -287,7 +262,6 @@ class PipelineTask(BasePipelineTask):
self._register_event_handler("on_pipeline_stopped")
self._register_event_handler("on_pipeline_ended")
self._register_event_handler("on_pipeline_cancelled")
self._register_event_handler("on_pipeline_finished")
@property
def params(self) -> PipelineParams:
@@ -316,27 +290,6 @@ class PipelineTask(BasePipelineTask):
"""
return self._turn_trace_observer
def event_handler(self, event_name: str):
"""Decorator for registering event handlers.
Args:
event_name: The name of the event to handle.
Returns:
The decorator function that registers the handler.
"""
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
DeprecationWarning,
)
return super().event_handler(event_name)
def add_observer(self, observer: BaseObserver):
"""Add an observer to monitor pipeline execution.
@@ -579,7 +532,6 @@ class PipelineTask(BasePipelineTask):
)
finally:
await self._call_event_handler("on_pipeline_cancelled", frame)
await self._call_event_handler("on_pipeline_finished", frame)
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
@@ -675,23 +627,13 @@ class PipelineTask(BasePipelineTask):
if isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
logger.debug(f"{self}: received end task frame {frame}")
await self.queue_frame(EndFrame())
elif isinstance(frame, CancelTaskFrame):
# Tell the task we should end right away.
logger.debug(f"{self}: received cancel task frame {frame}")
await self.queue_frame(CancelFrame())
elif isinstance(frame, StopTaskFrame):
# Tell the task we should stop nicely.
logger.debug(f"{self}: received stop task frame {frame}")
await self.queue_frame(StopFrame())
elif isinstance(frame, InterruptionTaskFrame):
# Tell the task we should interrupt the pipeline. Note that we are
# bypassing the push queue and directly queue into the
# pipeline. This is in case the push task is blocked waiting for a
# pipeline-ending frame to finish traversing the pipeline.
logger.debug(f"{self}: received interruption task frame {frame}")
await self._pipeline.queue_frame(InterruptionFrame())
elif isinstance(frame, ErrorFrame):
if frame.fatal:
logger.error(f"A fatal error occurred: {frame}")
@@ -700,7 +642,7 @@ class PipelineTask(BasePipelineTask):
# Tell the task we should stop.
await self.queue_frame(StopTaskFrame())
else:
logger.warning(f"{self}: Something went wrong: {frame}")
logger.warning(f"Something went wrong: {frame}")
async def _sink_push_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames coming downstream from the pipeline.
@@ -727,11 +669,9 @@ class PipelineTask(BasePipelineTask):
self._pipeline_start_event.set()
elif isinstance(frame, EndFrame):
await self._call_event_handler("on_pipeline_ended", frame)
await self._call_event_handler("on_pipeline_finished", frame)
self._pipeline_end_event.set()
elif isinstance(frame, StopFrame):
await self._call_event_handler("on_pipeline_stopped", frame)
await self._call_event_handler("on_pipeline_finished", frame)
self._pipeline_end_event.set()
elif isinstance(frame, CancelFrame):
self._pipeline_end_event.set()

View File

@@ -16,6 +16,7 @@ from typing import Optional
from pipecat.audio.dtmf.types import KeypadEntry
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EndFrame,
Frame,
@@ -23,7 +24,7 @@ from pipecat.frames.frames import (
StartFrame,
TranscriptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.time import time_now_iso8601
@@ -104,7 +105,7 @@ class DTMFAggregator(FrameProcessor):
# For first digit, schedule interruption.
if is_first_digit:
await self.push_interruption_task_frame_and_wait()
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
# Check for immediate flush conditions
if frame.button == self._termination_digit:

View File

@@ -22,6 +22,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -35,7 +36,6 @@ from pipecat.frames.frames import (
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
@@ -48,6 +48,7 @@ from pipecat.frames.frames import (
OpenAILLMContextAssistantTimestampFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserImageRawFrame,
@@ -137,7 +138,7 @@ class LLMFullResponseAggregator(FrameProcessor):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._call_event_handler("on_completion", self._aggregation, False)
self._aggregation = ""
self._started = False
@@ -531,9 +532,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing interruption and aggregation"
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
)
await self.push_interruption_task_frame_and_wait()
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")
@@ -837,7 +838,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
@@ -903,7 +904,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: InterruptionFrame):
async def _handle_interruptions(self, frame: StartInterruptionFrame):
await self.push_aggregation()
self._started = 0
await self.reset()

View File

@@ -13,6 +13,7 @@ LLM processing, and text-to-speech components in conversational AI pipelines.
import asyncio
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Set
from loguru import logger
@@ -22,6 +23,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -35,7 +37,6 @@ from pipecat.frames.frames import (
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMContextAssistantTimestampFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
@@ -47,6 +48,7 @@ from pipecat.frames.frames import (
LLMSetToolsFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserImageRawFrame,
@@ -309,9 +311,9 @@ class LLMUserAggregator(LLMContextAggregator):
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing interruption and aggregation"
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
)
await self.push_interruption_task_frame_and_wait()
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")
@@ -577,7 +579,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
@@ -643,7 +645,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: InterruptionFrame):
async def _handle_interruptions(self, frame: StartInterruptionFrame):
await self._push_aggregation()
self._started = 0
await self.reset()

View File

@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
return self._num_channels
def has_audio(self) -> bool:
"""Check if either user or bot audio buffers contain data.
"""Check if both user and bot audio buffers contain data.
Returns:
True if either buffer contains audio data.
True if both buffers contain audio data.
"""
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
self._bot_audio_buffer
)

View File

@@ -25,8 +25,8 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
STTMuteFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -204,7 +204,7 @@ class STTMuteFilter(FrameProcessor):
if isinstance(
frame,
(
InterruptionFrame,
StartInterruptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
UserStartedSpeakingFrame,

View File

@@ -28,9 +28,8 @@ from pipecat.frames.frames import (
FrameProcessorPauseUrgentFrame,
FrameProcessorResumeFrame,
FrameProcessorResumeUrgentFrame,
InterruptionFrame,
InterruptionTaskFrame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
@@ -220,14 +219,6 @@ class FrameProcessor(BaseObject):
self.__process_event: Optional[asyncio.Event] = None
self.__process_frame_task: Optional[asyncio.Task] = None
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
# Then we wait for the corresponding `InterruptionFrame` to travel from
# the start of the pipeline back to the processor that sent the
# `InterruptionTaskFrame`. This wait is handled using the following
# event.
self._wait_for_interruption = False
self._wait_interruption_event = asyncio.Event()
@property
def id(self) -> int:
"""Get the unique identifier for this processor.
@@ -551,14 +542,6 @@ class FrameProcessor(BaseObject):
if self._cancelling:
return
# If we are waiting for an interruption we will bypass all queued system
# frames and we will process the frame right away. This is because a
# previous system frame might be waiting for the interruption frame and
# it's blocking the input task.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
await self.__process_frame(frame, direction, callback)
return
if self._enable_direct_mode:
await self.__process_frame(frame, direction, callback)
else:
@@ -568,17 +551,11 @@ class FrameProcessor(BaseObject):
"""Pause processing of queued frames."""
logger.trace(f"{self}: pausing frame processing")
self.__should_block_frames = True
# We should also unset the process event here, in case it was set immediately after an interruption
if self.__process_event:
self.__process_event.clear()
async def pause_processing_system_frames(self):
"""Pause processing of queued system frames."""
logger.trace(f"{self}: pausing system frame processing")
self.__should_block_system_frames = True
# We should also unset the input event here, in case it was set immediately after an interruption
if self.__input_event:
self.__input_event.clear()
async def resume_processing_frames(self):
"""Resume processing of queued frames."""
@@ -611,7 +588,7 @@ class FrameProcessor(BaseObject):
if isinstance(frame, StartFrame):
await self.__start(frame)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self._start_interruption()
await self.stop_all_metrics()
elif isinstance(frame, CancelFrame):
@@ -643,34 +620,6 @@ class FrameProcessor(BaseObject):
await self.__internal_push_frame(frame, direction)
# If we are waiting for an interruption and we get an interruption, then
# we can unblock `push_interruption_task_frame_and_wait()`.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
self._wait_interruption_event.set()
async def push_interruption_task_frame_and_wait(self):
"""Push an interruption task frame upstream and wait for the interruption.
This function sends an `InterruptionTaskFrame` upstream to the pipeline
task and waits to receive the corresponding `InterruptionFrame`. When
the function finishes it is guaranteed that the `InterruptionFrame` has
been pushed downstream.
"""
self._wait_for_interruption = True
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
# Wait for an `InterruptionFrame` to come to this processor and be
# pushed. Take a look at `push_frame()` to see how we first push the
# `InterruptionFrame` and then we set the event in order to maintain
# frame ordering.
await self._wait_interruption_event.wait()
# Clear the event.
self._wait_interruption_event.clear()
self._wait_for_interruption = False
async def __start(self, frame: StartFrame):
"""Handle the start frame to initialize processor state.
@@ -720,22 +669,20 @@ class FrameProcessor(BaseObject):
async def _start_interruption(self):
"""Start handling an interruption by cancelling current tasks."""
try:
if self._wait_for_interruption:
# If we get here we know the process task was just waiting for
# an interruption (push_interruption_task_frame_and_wait()), so
# we can't cancel the task because it might still need to do
# more things (e.g. pushing a frame after the
# interruption). Instead we just drain the queue because this is
# an interruption.
self.__reset_process_task()
else:
# Cancel and re-create the process task including the queue.
await self.__cancel_process_task()
self.__create_process_task()
# Cancel the process task. This will stop processing queued frames.
await self.__cancel_process_task()
except Exception as e:
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
await self.push_error(ErrorFrame(str(e)))
# Create a new process queue and task.
self.__create_process_task()
async def _stop_interruption(self):
"""Stop handling an interruption."""
# Nothing to do right now.
pass
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
"""Internal method to push frames to adjacent processors.
@@ -817,17 +764,6 @@ class FrameProcessor(BaseObject):
self.__process_queue = asyncio.Queue()
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
def __reset_process_task(self):
"""Reset non-system frame processing task."""
if self._enable_direct_mode:
return
self.__should_block_frames = False
self.__process_event = asyncio.Event()
while not self.__process_queue.empty():
self.__process_queue.get_nowait()
self.__process_queue.task_done()
async def __cancel_process_task(self):
"""Cancel the non-system frame processing task."""
if self.__process_frame_task:
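
The push-and-wait handshake described in the `push_interruption_task_frame_and_wait()` comments can be illustrated in isolation. This is a simplified sketch, not the FrameProcessor implementation: frames are plain strings, and `WaitingProcessor`, `on_frame` and `run_demo` are hypothetical names.

```python
import asyncio


class WaitingProcessor:
    """Hypothetical processor that requests an interruption and waits for it."""

    def __init__(self, send_upstream):
        self._send_upstream = send_upstream  # hypothetical hook towards the task
        self._wait_for_interruption = False
        self._wait_interruption_event = asyncio.Event()
        self.log = []

    async def push_interruption_and_wait(self):
        self._wait_for_interruption = True
        await self._send_upstream("InterruptionTaskFrame")
        # Block until the matching InterruptionFrame has passed through us.
        await self._wait_interruption_event.wait()
        self._wait_interruption_event.clear()
        self._wait_for_interruption = False
        self.log.append("resumed")

    async def on_frame(self, frame: str):
        # Push first, then release the waiter, so frame ordering is kept.
        self.log.append(f"pushed {frame}")
        if self._wait_for_interruption and frame == "InterruptionFrame":
            self._wait_interruption_event.set()


async def run_demo():
    pending = []  # keep task references alive until they finish

    async def task_source(frame: str):
        # Stand-in for the pipeline task: it answers the request by
        # sending an InterruptionFrame back down asynchronously.
        pending.append(asyncio.create_task(processor.on_frame("InterruptionFrame")))

    processor = WaitingProcessor(task_source)
    await processor.push_interruption_and_wait()
    await asyncio.gather(*pending)
    return processor.log
```

Setting the event only after the frame has been pushed is what lets the caller assume the interruption is already downstream when the wait returns.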

View File

@@ -30,6 +30,7 @@ from loguru import logger
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -1205,7 +1206,7 @@ class RTVIProcessor(FrameProcessor):
async def interrupt_bot(self):
"""Send a bot interruption frame upstream."""
await self.push_interruption_task_frame_and_wait()
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def send_server_message(self, data: Any):
"""Send a server message to the client."""

View File

@@ -19,7 +19,7 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterruptionFrame,
StartInterruptionFrame,
TranscriptionFrame,
TranscriptionMessage,
TranscriptionUpdateFrame,
@@ -86,7 +86,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
transcript messages. Utterances are completed when:
- The bot stops speaking (BotStoppedSpeakingFrame)
- The bot is interrupted (InterruptionFrame)
- The bot is interrupted (StartInterruptionFrame)
- The pipeline ends (EndFrame)
"""
@@ -185,7 +185,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
- TTSTextFrame: Aggregates text for current utterance
- BotStoppedSpeakingFrame: Completes current utterance
- InterruptionFrame: Completes current utterance due to interruption
- StartInterruptionFrame: Completes current utterance due to interruption
- EndFrame: Completes current utterance at pipeline end
- CancelFrame: Completes current utterance due to cancellation
@@ -195,7 +195,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
"""
await super().process_frame(frame, direction)
if isinstance(frame, (InterruptionFrame, CancelFrame)):
if isinstance(frame, (StartInterruptionFrame, CancelFrame)):
# Push frame first otherwise our emitted transcription update frame
# might get cleaned up.
await self.push_frame(frame, direction)

View File

@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
StartFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -184,13 +185,15 @@ class UserIdleProcessor(FrameProcessor):
Runs in a loop until cancelled or callback indicates completion.
"""
running = True
while running:
while True:
try:
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
except asyncio.TimeoutError:
if not self._interrupted:
self._retry_count += 1
running = await self._callback(self, self._retry_count)
should_continue = await self._callback(self, self._retry_count)
if not should_continue:
await self._stop()
break
finally:
self._idle_event.clear()
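
The idle loop above combines `asyncio.wait_for` on an event with a callback whose return value decides whether monitoring continues. Here is a reduced sketch of that control flow. `idle_monitor` and `run_idle_demo` are hypothetical names, and the `_interrupted` check and `_stop()` call from the diff are left out.

```python
import asyncio


async def idle_monitor(idle_event: asyncio.Event, timeout: float, callback) -> int:
    """Call `callback` on every idle timeout until it returns False."""
    retry_count = 0
    while True:
        try:
            # Activity elsewhere sets the event and restarts the wait.
            await asyncio.wait_for(idle_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            retry_count += 1
            should_continue = await callback(retry_count)
            if not should_continue:
                break
        finally:
            # Runs on every iteration, including the one that breaks.
            idle_event.clear()
    return retry_count


async def run_idle_demo() -> int:
    async def callback(retry_count: int) -> bool:
        # Give up after the third consecutive idle timeout.
        return retry_count < 3

    return await idle_monitor(asyncio.Event(), timeout=0.01, callback=callback)
```

In the demo nothing ever sets the event, so every wait times out and the callback ends the loop on its third call.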

View File

@@ -70,6 +70,7 @@ import asyncio
import os
import sys
from contextlib import asynccontextmanager
from typing import Dict
from loguru import logger
@@ -182,14 +183,13 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.transports.smallwebrtc.request_handler import (
SmallWebRTCRequest,
SmallWebRTCRequestHandler,
)
except ImportError as e:
logger.error(f"WebRTC transport dependencies not installed: {e}")
return
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
# Mount the frontend
app.mount("/client", SmallWebRTCPrebuiltUI)
@@ -198,33 +198,51 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
"""Redirect root requests to client interface."""
return RedirectResponse(url="/client/")
# Initialize the SmallWebRTC request handler
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
esp32_mode=esp32_mode, host=host
)
@app.post("/api/offer")
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
async def offer(request: dict, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests and manage peer connections."""
pc_id = request.get("pc_id")
if pc_id and pc_id in pcs_map:
pipecat_connection = pcs_map[pc_id]
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request["sdp"],
type=request["type"],
restart_pc=request.get("restart_pc", False),
)
else:
pipecat_connection = SmallWebRTCConnection()
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
"""Handle WebRTC connection closure and cleanup."""
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
pcs_map.pop(webrtc_connection.pc_id, None)
# Prepare runner arguments with the callback to run your bot
async def webrtc_connection_callback(connection):
bot_module = _get_bot_module()
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
background_tasks.add_task(bot_module.bot, runner_args)
# Delegate handling to SmallWebRTCRequestHandler
answer = await small_webrtc_handler.handle_web_request(
request=request,
webrtc_connection_callback=webrtc_connection_callback,
)
answer = pipecat_connection.get_answer()
# Apply ESP32 SDP munging if enabled
if esp32_mode and host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
pcs_map[answer["pc_id"]] = pipecat_connection
return answer
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage FastAPI application lifecycle and cleanup connections."""
yield
await small_webrtc_handler.close()
coros = [pc.disconnect() for pc in pcs_map.values()]
await asyncio.gather(*coros)
pcs_map.clear()
app.router.lifespan_context = lifespan

View File

@@ -51,11 +51,9 @@ class WebSocketRunnerArguments(RunnerArguments):
Parameters:
websocket: WebSocket connection for audio streaming
body: Additional request data
"""
websocket: WebSocket
body: Optional[Any] = field(default_factory=dict)
@dataclass

View File

@@ -99,35 +99,16 @@ async def parse_telephony_websocket(websocket: WebSocket):
tuple: (transport_type: str, call_data: dict)
call_data contains provider-specific fields:
- Twilio: {
"stream_id": str,
"call_id": str,
"body": dict
}
- Telnyx: {
"stream_id": str,
"call_control_id": str,
"outbound_encoding": str,
"from": str,
"to": str,
}
- Plivo: {
"stream_id": str,
"call_id": str,
}
- Exotel: {
"stream_id": str,
"call_id": str,
"account_sid": str,
"from": str,
"to": str,
}
- Twilio: {"stream_id": str, "call_id": str}
- Telnyx: {"stream_id": str, "call_control_id": str, "outbound_encoding": str}
- Plivo: {"stream_id": str, "call_id": str}
- Exotel: {"stream_id": str, "call_id": str, "account_sid": str}
Example usage::
transport_type, call_data = await parse_telephony_websocket(websocket)
if transport_type == "twilio":
user_id = call_data["body"]["user_id"]
if transport_type == "telnyx":
outbound_encoding = call_data["outbound_encoding"]
"""
# Read first two messages
start_data = websocket.iter_text()
@@ -170,12 +151,9 @@ async def parse_telephony_websocket(websocket: WebSocket):
# Extract provider-specific data
if transport_type == "twilio":
start_data = call_data_raw.get("start", {})
body_data = start_data.get("customParameters", {})
call_data = {
"stream_id": start_data.get("streamSid"),
"call_id": start_data.get("callSid"),
# All custom parameters
"body": body_data,
}
elif transport_type == "telnyx":
@@ -185,8 +163,6 @@ async def parse_telephony_websocket(websocket: WebSocket):
"outbound_encoding": call_data_raw.get("start", {})
.get("media_format", {})
.get("encoding"),
"from": call_data_raw.get("start", {}).get("from", ""),
"to": call_data_raw.get("start", {}).get("to", ""),
}
elif transport_type == "plivo":
@@ -202,8 +178,6 @@ async def parse_telephony_websocket(websocket: WebSocket):
"stream_id": start_data.get("stream_sid"),
"call_id": start_data.get("call_sid"),
"account_sid": start_data.get("account_sid"),
"from": start_data.get("from", ""),
"to": start_data.get("to", ""),
}
else:
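
Reviewer sketch (not part of the diff): the Twilio branch above differs between the two sides only in whether `customParameters` is surfaced as `call_data["body"]`. A minimal standalone illustration, assuming the start message has the shape implied by the hunk; the helper name `extract_twilio_call_data` and the sample payload are hypothetical.

```python
from typing import Any, Dict


def extract_twilio_call_data(message: Dict[str, Any]) -> Dict[str, Any]:
    """Build call_data from a Twilio "start" message (hypothetical helper)."""
    start_data = message.get("start", {})
    return {
        "stream_id": start_data.get("streamSid"),
        "call_id": start_data.get("callSid"),
        # All custom parameters; an empty dict when none were sent
        "body": start_data.get("customParameters", {}),
    }


# Made-up sample payload, for illustration only
sample_message = {
    "event": "start",
    "start": {
        "streamSid": "MZ-example",
        "callSid": "CA-example",
        "customParameters": {"user_id": "u-123"},
    },
}
call_data = extract_twilio_call_data(sample_message)
user_id = call_data["body"]["user_id"]
```

On the side of the diff without `"body"`, the `call_data["body"]["user_id"]` lookup shown in the docstring example would raise `KeyError`.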

View File

@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InputDTMFFrame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -98,7 +98,7 @@ class ExotelFrameSerializer(FrameSerializer):
Returns:
Serialized data as string or bytes, or None if the frame isn't handled.
"""
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
answer = {"event": "clear", "streamSid": self._stream_sid}
return json.dumps(answer)
elif isinstance(frame, AudioRawFrame):

View File

@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InputDTMFFrame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -122,7 +122,7 @@ class PlivoFrameSerializer(FrameSerializer):
self._hangup_attempted = True
await self._hang_up_call()
return None
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
answer = {"event": "clearAudio", "streamId": self._stream_id}
return json.dumps(answer)
elif isinstance(frame, AudioRawFrame):

View File

@@ -29,8 +29,8 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InputDTMFFrame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
@@ -137,7 +137,7 @@ class TelnyxFrameSerializer(FrameSerializer):
self._hangup_attempted = True
await self._hang_up_call()
return None
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
answer = {"event": "clear"}
return json.dumps(answer)
elif isinstance(frame, AudioRawFrame):

View File

@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InputDTMFFrame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -122,7 +122,7 @@ class TwilioFrameSerializer(FrameSerializer):
self._hangup_attempted = True
await self._hang_up_call()
return None
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
answer = {"event": "clear", "streamSid": self._stream_sid}
return json.dumps(answer)
elif isinstance(frame, AudioRawFrame):
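
Reviewer sketch (not part of the diff): the four telephony serializers above all answer an interruption frame with a small JSON "clear" message, and only the frame class name changes between the two sides. The payloads, as they appear in the hunks, can be summarized in one hypothetical function; `clear_message` and its `provider` argument are illustrative names, not Pipecat API.

```python
import json
from typing import Optional


def clear_message(provider: str, stream_id: Optional[str] = None) -> str:
    """Return the JSON sent to a provider to flush buffered audio (illustrative)."""
    if provider in ("twilio", "exotel"):
        payload = {"event": "clear", "streamSid": stream_id}
    elif provider == "plivo":
        payload = {"event": "clearAudio", "streamId": stream_id}
    elif provider == "telnyx":
        # The Telnyx hunk sends no stream identifier
        payload = {"event": "clear"}
    else:
        raise ValueError(f"Unknown provider: {provider}")
    return json.dumps(payload)


twilio_clear = clear_message("twilio", "MZ-example")
plivo_clear = clear_message("plivo", "stream-1")
```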

View File

@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -119,6 +119,7 @@ class AsyncAITTSService(InterruptibleTTSService):
"""
super().__init__(
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
push_stop_frames=True,
sample_rate=sample_rate,
@@ -274,7 +275,7 @@ class AsyncAITTSService(InterruptibleTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
async def _receive_messages(self):

View File

@@ -25,10 +25,7 @@ from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.bedrock_adapter import (
AWSBedrockLLMAdapter,
AWSBedrockLLMInvocationParams,
)
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.frames.frames import (
Frame,
FunctionCallCancelFrame,
@@ -811,55 +808,64 @@ class AWSBedrockLLMService(LLMService):
Returns:
The LLM's response as a string, or None if no response is generated.
"""
messages = []
system = []
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
messages = params["messages"]
system = params["system"] # [{"text": "system message"}]
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
try:
messages = []
system = []
if isinstance(context, LLMContext):
# Future code will be something like this:
# adapter = self.get_llm_adapter()
# params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
# messages = params["messages"]
# system = params["system_instruction"] # [{"text": "system message"}]
raise NotImplementedError(
"Universal LLMContext is not yet supported for AWS Bedrock."
)
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
if system:
request_params["system"] = system
if system:
request_params["system"] = system
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
return None
except Exception as e:
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
return None
async def _create_converse_stream(self, client, request_params):
@@ -934,25 +940,8 @@ class AWSBedrockLLMService(LLMService):
}
}
def _get_llm_invocation_params(
self, context: OpenAILLMContext | LLMContext
) -> AWSBedrockLLMInvocationParams:
# Universal LLMContext
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params = adapter.get_llm_invocation_params(context)
return params
# AWS Bedrock-specific context
return AWSBedrockLLMInvocationParams(
system=getattr(context, "system", None),
messages=context.messages,
tools=context.tools or [],
tool_choice=context.tool_choice,
)
@traced_llm
async def _process_context(self, context: AWSBedrockLLMContext | LLMContext):
async def _process_context(self, context: AWSBedrockLLMContext):
# Usage tracking
prompt_tokens = 0
completion_tokens = 0
@@ -969,12 +958,6 @@ class AWSBedrockLLMService(LLMService):
await self.start_ttfb_metrics()
params_from_context = self._get_llm_invocation_params(context)
messages = params_from_context["messages"]
system = params_from_context["system"]
tools = params_from_context["tools"]
tool_choice = params_from_context["tool_choice"]
# Set up inference config
inference_config = {
"maxTokens": self._settings["max_tokens"],
@@ -985,18 +968,19 @@ class AWSBedrockLLMService(LLMService):
# Prepare request parameters
request_params = {
"modelId": self.model_name,
"messages": messages,
"messages": context.messages,
"inferenceConfig": inference_config,
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
}
# Add system message
system = getattr(context, "system", None)
if system:
request_params["system"] = system
# Check if messages contain tool use or tool result content blocks
has_tool_content = False
for message in messages:
for message in context.messages:
if isinstance(message.get("content"), list):
for content_item in message["content"]:
if "toolUse" in content_item or "toolResult" in content_item:
@@ -1006,6 +990,7 @@ class AWSBedrockLLMService(LLMService):
break
# Handle tools: use current tools, or no-op if tool content exists but no current tools
tools = context.tools or []
if has_tool_content and not tools:
tools = [self._create_no_op_tool()]
using_noop_tool = True
@@ -1014,15 +999,17 @@ class AWSBedrockLLMService(LLMService):
tool_config = {"tools": tools}
# Only add tool_choice if we have real tools (not just no-op)
if not using_noop_tool and tool_choice:
if tool_choice == "auto":
if not using_noop_tool and context.tool_choice:
if context.tool_choice == "auto":
tool_config["toolChoice"] = {"auto": {}}
elif tool_choice == "none":
elif context.tool_choice == "none":
# Skip adding toolChoice for "none"
pass
elif isinstance(tool_choice, dict) and "function" in tool_choice:
elif (
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
):
tool_config["toolChoice"] = {
"tool": {"name": tool_choice["function"]["name"]}
"tool": {"name": context.tool_choice["function"]["name"]}
}
request_params["toolConfig"] = tool_config
@@ -1032,16 +1019,9 @@ class AWSBedrockLLMService(LLMService):
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
# Log request params with messages redacted for logging
if isinstance(context, LLMContext):
adapter = self.get_llm_adapter()
context_type_for_logging = "universal"
messages_for_logging = adapter.get_messages_for_logging(context)
else:
context_type_for_logging = "LLM-specific"
messages_for_logging = context.get_messages_for_logging()
logger.debug(
f"{self}: Generating chat from {context_type_for_logging} context [{system}] | {messages_for_logging}"
)
log_params = dict(request_params)
log_params["messages"] = context.get_messages_for_logging()
logger.debug(f"Calling AWS Bedrock model with: {log_params}")
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
@@ -1149,7 +1129,7 @@ class AWSBedrockLLMService(LLMService):
if isinstance(frame, OpenAILLMContextFrame):
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
if isinstance(frame, LLMContextFrame):
context = frame.context
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
elif isinstance(frame, LLMMessagesFrame):
context = AWSBedrockLLMContext.from_messages(frame.messages)
elif isinstance(frame, LLMUpdateSettingsFrame):
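
Reviewer sketch (not part of the diff): the one-shot inference hunk earlier in this file extracts the first text block from a Bedrock `converse` response. The same walk, isolated from the service class, might look like the following; `extract_text` and the sample response are hypothetical, and the response shape is assumed from the keys the hunk checks.

```python
from typing import Any, Dict, Optional


def extract_text(response: Dict[str, Any]) -> Optional[str]:
    """Return the first non-empty text item from a converse-style response."""
    content = response.get("output", {}).get("message", {}).get("content")
    if isinstance(content, list):
        for item in content:
            if item.get("text"):
                return item["text"]
    elif isinstance(content, str):
        return content
    return None


# Made-up response, shaped like the keys the hunk inspects
sample_response = {
    "output": {"message": {"content": [{"toolUse": {}}, {"text": "A short summary."}]}}
}
summary = extract_text(sample_response)
```

One behavioral difference between the two sides is worth noting: the variant wrapped in `try`/`except` logs any failure and returns `None`, while the other lets exceptions propagate to the caller.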

View File

@@ -532,7 +532,9 @@ class AWSTranscribeSTTService(STTService):
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")

View File

@@ -247,14 +247,13 @@ class AWSNovaSonicLLMService(LLMService):
self._ready_to_send_context = False
self._handling_bot_stopped_speaking = False
self._triggering_assistant_response = False
self._assistant_response_trigger_audio: Optional[bytes] = (
None # Not cleared on _disconnect()
)
self._disconnecting = False
self._connected_time: Optional[float] = None
self._wants_connection = False
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
with wave.open(file_path.open("rb"), "rb") as wav_file:
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
#
# standard AIService frame handling
#
@@ -1100,13 +1099,20 @@ class AWSNovaSonicLLMService(LLMService):
self._triggering_assistant_response = True
# Read audio bytes, if we don't already have them cached
if not self._assistant_response_trigger_audio:
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
with wave.open(file_path.open("rb"), "rb") as wav_file:
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
# Send the trigger audio, if we're fully connected and set up
if self._connected_time:
if self._connected_time is not None:
await self._send_assistant_response_trigger()
async def _send_assistant_response_trigger(self):
if not self._connected_time:
# should never happen
if (
not self._assistant_response_trigger_audio or self._connected_time is None
): # should never happen
return
try:

View File

@@ -21,13 +21,13 @@ from pipecat.frames.frames import (
DataFrame,
Frame,
FunctionCallResultFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
)
@@ -306,7 +306,7 @@ class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
if isinstance(
frame,
(
InterruptionFrame,
StartInterruptionFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
TextFrame,

View File

@@ -19,7 +19,6 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
@@ -141,7 +140,6 @@ class AzureSTTService(STTService):
self._speech_recognizer = SpeechRecognizer(
speech_config=self._speech_config, audio_config=audio_config
)
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
self._speech_recognizer.start_continuous_recognition_async()
@@ -199,15 +197,3 @@ class AzureSTTService(STTService):
self._handle_transcription(event.result.text, True, language), self.get_event_loop()
)
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
def _on_handle_recognizing(self, event):
if event.result.reason == ResultReason.RecognizingSpeech and len(event.result.text) > 0:
language = getattr(event.result, "language", None) or self._settings.get("language")
frame = InterimTranscriptionFrame(
event.result.text,
self._user_id,
time_now_iso8601(),
language,
result=event,
)
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())

View File

@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -371,7 +371,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
if self._context_id:

View File

@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -460,7 +460,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
if isinstance(frame, TTSStoppedFrame):
await self.add_word_timestamps([("Reset", 0)])
@@ -549,7 +549,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
"""Handle interruption by closing the current context."""
await super()._handle_interruption(frame, direction)
@@ -558,7 +558,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
logger.trace(f"Closing context {self._context_id} due to interruption")
try:
# ElevenLabs requires that Pipecat manages the contexts and closes them
# when they're no longer in use. Since an InterruptionFrame is pushed
# when they're no longer in use. Since a StartInterruptionFrame is pushed
# every time the user speaks, we'll use this as a trigger to close the context
# and reset the state.
# Note: We do not need to call remove_audio_context here, as the context is
@@ -856,7 +856,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
if isinstance(frame, (StartInterruptionFrame, TTSStoppedFrame)):
# Reset timing on interruption or stop
self._reset_state()

View File

@@ -21,8 +21,8 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -259,7 +259,7 @@ class FishAudioTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None

View File

@@ -33,7 +33,6 @@ from pipecat.frames.frames import (
InputAudioRawFrame,
InputImageRawFrame,
InputTextRawFrame,
InterruptionFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -42,6 +41,7 @@ from pipecat.frames.frames import (
LLMTextFrame,
LLMUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -752,7 +752,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
elif isinstance(frame, InputImageRawFrame):
await self._send_user_video(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self._handle_interruption()
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):

View File

@@ -13,7 +13,6 @@ supporting multiple languages, custom vocabulary, and various audio processing o
import asyncio
import base64
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Literal, Optional
import aiohttp
@@ -174,6 +173,8 @@ class _InputParamsDescriptor:
"""Descriptor for backward compatibility with deprecation warning."""
def __get__(self, obj, objtype=None):
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -207,7 +208,7 @@ class GladiaSTTService(STTService):
api_key: str,
region: Literal["us-west", "eu-west"] | None = None,
url: str = "https://api.gladia.io/v2/live",
confidence: Optional[float] = None,
confidence: float = 0.5,
sample_rate: Optional[int] = None,
model: str = "solaria-1",
params: Optional[GladiaInputParams] = None,
@@ -223,11 +224,6 @@ class GladiaSTTService(STTService):
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
.. deprecated:: 0.0.86
The 'confidence' parameter is deprecated and will be removed in a future version.
No confidence threshold is applied.
sample_rate: Audio sample rate in Hz. If None, uses service default.
model: Model to use for transcription. Defaults to "solaria-1".
params: Additional configuration parameters for Gladia service.
@@ -240,6 +236,7 @@ class GladiaSTTService(STTService):
params = params or GladiaInputParams()
# Warn about deprecated language parameter if it's used
if params.language is not None:
with warnings.catch_warnings():
warnings.simplefilter("always")
@@ -250,20 +247,11 @@ class GladiaSTTService(STTService):
stacklevel=2,
)
if confidence:
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"The 'confidence' parameter is deprecated and will be removed in a future version. "
"No confidence threshold is applied.",
DeprecationWarning,
stacklevel=2,
)
self._api_key = api_key
self._region = region
self._url = url
self.set_model_name(model)
self._confidence = confidence
self._params = params
self._websocket = None
self._receive_task = None
@@ -587,40 +575,43 @@ class GladiaSTTService(STTService):
elif content["type"] == "transcript":
utterance = content["data"]["utterance"]
confidence = utterance.get("confidence", 0)
language = utterance["language"]
transcript = utterance["text"]
is_final = content["data"]["is_final"]
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
if confidence >= self._confidence:
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
)
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
)
elif content["type"] == "translation":
translated_utterance = content["data"]["translated_utterance"]
original_language = content["data"]["original_language"]
translated_language = translated_utterance["language"]
confidence = translated_utterance.get("confidence", 0)
translation = translated_utterance["text"]
if translated_language != original_language:
if translated_language != original_language and confidence >= self._confidence:
await self.push_frame(
TranslationFrame(
translation, "", time_now_iso8601(), translated_language
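
Reviewer sketch (not part of the diff): on the side of this hunk that keeps the threshold, transcripts and translations are pushed only when `confidence >= self._confidence`. A minimal standalone version of that gate, assuming utterances are plain dicts as in the hunk; `passes_confidence` is a hypothetical name.

```python
def passes_confidence(utterance: dict, threshold: float = 0.5) -> bool:
    """Mirror the gate in the hunk: a missing confidence is treated as 0."""
    return utterance.get("confidence", 0) >= threshold


kept = passes_confidence({"text": "hello", "confidence": 0.82})
dropped = passes_confidence({"text": "hello"})  # no confidence field present
```

Because a missing `confidence` field defaults to 0, any message without one is silently discarded under the 0.5 default, which may be relevant when comparing the two sides.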

View File

@@ -83,23 +83,14 @@ class GoogleVertexLLMService(OpenAILLMService):
self._api_key = self._get_api_token(credentials, credentials_path)
super().__init__(
api_key=self._api_key,
base_url=base_url,
model=model,
params=params,
**kwargs,
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
)
@staticmethod
def _get_base_url(params: InputParams) -> str:
"""Construct the base URL for Vertex AI API."""
# Determine the correct API host based on location
if params.location == "global":
api_host = "aiplatform.googleapis.com"
else:
api_host = f"{params.location}-aiplatform.googleapis.com"
return (
f"https://{api_host}/v1/"
f"https://{params.location}-aiplatform.googleapis.com/v1/"
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
)
@@ -127,14 +118,12 @@ class GoogleVertexLLMService(OpenAILLMService):
if credentials:
# Parse and load credentials from JSON string
creds = service_account.Credentials.from_service_account_info(
json.loads(credentials),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
elif credentials_path:
# Load credentials from JSON file
creds = service_account.Credentials.from_service_account_file(
credentials_path,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
else:
try:
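
Reviewer sketch (not part of the diff): the `_get_base_url` hunk above shows both host-selection variants. The one that special-cases `"global"` can be written as a free function for comparison; `vertex_base_url` and the project id used below are hypothetical.

```python
def vertex_base_url(project_id: str, location: str) -> str:
    """Build the OpenAI-compatible Vertex AI endpoint URL (illustrative)."""
    # "global" has no regional prefix on the API host
    if location == "global":
        api_host = "aiplatform.googleapis.com"
    else:
        api_host = f"{location}-aiplatform.googleapis.com"
    return (
        f"https://{api_host}/v1/"
        f"projects/{project_id}/locations/{location}/endpoints/openapi"
    )


regional_url = vertex_base_url("demo-project", "us-east4")
global_url = vertex_base_url("demo-project", "global")
```

With the single-line variant, a `"global"` location would instead produce the host `global-aiplatform.googleapis.com`.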

View File

@@ -500,11 +500,9 @@ class GoogleTTSService(TTSService):
Parameters:
language: Language for synthesis. Defaults to English.
speaking_rate: The speaking rate, in the range [0.25, 4.0].
"""
language: Optional[Language] = Language.EN
speaking_rate: Optional[float] = None
def __init__(
self,
@@ -512,7 +510,6 @@ class GoogleTTSService(TTSService):
credentials: Optional[str] = None,
credentials_path: Optional[str] = None,
voice_id: str = "en-US-Chirp3-HD-Charon",
voice_cloning_key: Optional[str] = None,
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
@@ -523,7 +520,6 @@ class GoogleTTSService(TTSService):
credentials: JSON string containing Google Cloud service account credentials.
credentials_path: Path to Google Cloud service account JSON file.
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
sample_rate: Audio sample rate in Hz. If None, uses default.
params: Language configuration parameters.
**kwargs: Additional arguments passed to parent TTSService.
@@ -536,10 +532,8 @@ class GoogleTTSService(TTSService):
"language": self.language_to_service_language(params.language)
if params.language
else "en-US",
"speaking_rate": params.speaking_rate,
}
self.set_voice(voice_id)
self._voice_cloning_key = voice_cloning_key
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
credentials, credentials_path
)
@@ -606,24 +600,15 @@ class GoogleTTSService(TTSService):
try:
await self.start_ttfb_metrics()
if self._voice_cloning_key:
voice_clone_params = texttospeech_v1.VoiceCloneParams(
voice_cloning_key=self._voice_cloning_key
)
voice = texttospeech_v1.VoiceSelectionParams(
language_code=self._settings["language"], voice_clone=voice_clone_params
)
else:
voice = texttospeech_v1.VoiceSelectionParams(
language_code=self._settings["language"], name=self._voice_id
)
voice = texttospeech_v1.VoiceSelectionParams(
language_code=self._settings["language"], name=self._voice_id
)
streaming_config = texttospeech_v1.StreamingSynthesizeConfig(
voice=voice,
streaming_audio_config=texttospeech_v1.StreamingAudioConfig(
audio_encoding=texttospeech_v1.AudioEncoding.PCM,
sample_rate_hertz=self.sample_rate,
speaking_rate=self._settings["speaking_rate"],
),
)
config_request = texttospeech_v1.StreamingSynthesizeRequest(

View File

@@ -240,7 +240,6 @@ class HeyGenVideoService(AIService):
# As soon as we receive actual audio, the base output transport will create a
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
await self.stop_ttfb_metrics()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -36,15 +36,15 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
FunctionCallResultProperties,
FunctionCallsStartedFrame,
InterruptionFrame,
LLMConfigureOutputFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
StartInterruptionFrame,
UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -195,17 +195,6 @@ class LLMService(AIService):
"""
return self._adapter
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return self.get_llm_adapter().create_llm_specific_message(message)
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
@@ -280,7 +269,7 @@ class LLMService(AIService):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
@@ -297,7 +286,7 @@ class LLMService(AIService):
await super().push_frame(frame, direction)
async def _handle_interruptions(self, _: InterruptionFrame):
async def _handle_interruptions(self, _: StartInterruptionFrame):
for function_name, entry in self._functions.items():
if entry.cancel_on_interruption:
await self._cancel_function_call(function_name)

View File

@@ -16,8 +16,8 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -180,7 +180,7 @@ class LmntTTSService(InterruptibleTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
async def _connect(self):

View File

@@ -7,7 +7,7 @@
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
import json
from typing import Any, Dict, List, TypeAlias
from typing import Any, Dict, List, Tuple
from loguru import logger
@@ -28,8 +28,6 @@ except ModuleNotFoundError as e:
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
raise Exception(f"Missing module: {e}")
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
class MCPClient(BaseObject):
"""Client for Model Context Protocol (MCP) servers.
@@ -44,7 +42,7 @@ class MCPClient(BaseObject):
def __init__(
self,
server_params: ServerParameters,
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
**kwargs,
):
"""Initialize the MCP client with server parameters.

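Reviewer sketch (not part of the diff): the two `server_params` annotations above mean different things to a type checker. A `Tuple[A, B, C]` annotation describes a three-element tuple, while the `A | B | C` alias describes a single value of any one of the types. The stand-in classes below are hypothetical and only replace the MCP parameter types so the example runs without the `mcp` package.

```python
from dataclasses import dataclass
from typing import Tuple, Union, get_args, get_origin


@dataclass
class StdioParams:  # stand-in for StdioServerParameters
    command: str


@dataclass
class SseParams:  # stand-in for SseServerParameters
    url: str


# "One of these types": matches how a single params object is passed
ServerParams = Union[StdioParams, SseParams]

# "A pair holding one of each": not what a single params object is
ServerParamsTuple = Tuple[StdioParams, SseParams]

params = SseParams(url="https://example.invalid/sse")
is_valid_union_member = isinstance(params, get_args(ServerParams))
tuple_origin = get_origin(ServerParamsTuple)
```

Neither annotation is enforced at runtime, so both sides behave identically when executed; the difference shows up only in static checking and documentation.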
View File

@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
@@ -224,7 +224,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -64,7 +64,6 @@ class OpenAITTSService(TTSService):
model: str = "gpt-4o-mini-tts",
sample_rate: Optional[int] = None,
instructions: Optional[str] = None,
speed: Optional[float] = None,
**kwargs,
):
"""Initialize OpenAI TTS service.
@@ -76,7 +75,6 @@ class OpenAITTSService(TTSService):
model: TTS model to use. Defaults to "gpt-4o-mini-tts".
sample_rate: Output audio sample rate in Hz. If None, uses OpenAI's default 24kHz.
instructions: Optional instructions to guide voice synthesis behavior.
speed: Voice speed control (0.25 to 4.0, default 1.0).
**kwargs: Additional keyword arguments passed to TTSService.
"""
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
@@ -86,7 +84,6 @@ class OpenAITTSService(TTSService):
)
super().__init__(sample_rate=sample_rate, **kwargs)
self._speed = speed
self.set_model_name(model)
self.set_voice(voice)
self._instructions = instructions
@@ -136,22 +133,17 @@ class OpenAITTSService(TTSService):
try:
await self.start_ttfb_metrics()
# Setup API parameters
create_params = {
"input": text,
"model": self.model_name,
"voice": VALID_VOICES[self._voice_id],
"response_format": "pcm",
}
# Setup extra body parameters
extra_body = {}
if self._instructions:
create_params["instructions"] = self._instructions
if self._speed:
create_params["speed"] = self._speed
extra_body["instructions"] = self._instructions
async with self._client.audio.speech.with_streaming_response.create(
**create_params
input=text,
model=self.model_name,
voice=VALID_VOICES[self._voice_id],
response_format="pcm",
extra_body=extra_body,
) as r:
if r.status_code != 200:
error = await r.text()
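
The two request-building styles in the hunk above differ in how optional parameters reach the API call. As a rough sketch (the helper `build_speech_params` is hypothetical and is not part of pipecat or the OpenAI SDK), optional fields can be added only when they are set. Comparing against `None` instead of relying on truthiness, as the `if self._speed:` check does, avoids silently dropping a falsy value such as `0`:

```python
from typing import Any, Dict, Optional


def build_speech_params(
    text: str,
    model: str,
    voice: str,
    instructions: Optional[str] = None,
    speed: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble keyword arguments for a TTS request."""
    params: Dict[str, Any] = {
        "input": text,
        "model": model,
        "voice": voice,
        "response_format": "pcm",
    }
    # Only include optional fields the caller actually set.
    if instructions is not None:
        params["instructions"] = instructions
    if speed is not None:
        params["speed"] = speed
    return params
```

The resulting dictionary could then be unpacked into the request call, in the same way the `**create_params` form does.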

View File

@@ -23,7 +23,6 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -32,6 +31,7 @@ from pipecat.frames.frames import (
LLMTextFrame,
LLMUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -366,7 +366,7 @@ class OpenAIRealtimeLLMService(LLMService):
elif isinstance(frame, InputAudioRawFrame):
if not self._audio_input_paused:
await self._send_user_audio(frame)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self._handle_interruption()
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
@@ -716,12 +716,14 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()
await self.push_interruption_task_frame_and_wait()
await self._start_interruption() # cancels this processor task
await self.push_frame(StartInterruptionFrame()) # cancels downstream tasks
await self.push_frame(UserStartedSpeakingFrame())
async def _handle_evt_speech_stopped(self, evt):
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._stop_interruption()
await self.push_frame(UserStoppedSpeakingFrame())
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):

View File

@@ -24,7 +24,6 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -33,6 +32,7 @@ from pipecat.frames.frames import (
LLMTextFrame,
LLMUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
@@ -364,7 +364,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
elif isinstance(frame, InputAudioRawFrame):
if not self._audio_input_paused:
await self._send_user_audio(frame)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self._handle_interruption()
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
@@ -658,12 +658,14 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()
await self.push_interruption_task_frame_and_wait()
await self._start_interruption() # cancels this processor task
await self.push_frame(StartInterruptionFrame()) # cancels downstream tasks
await self.push_frame(UserStartedSpeakingFrame())
async def _handle_evt_speech_stopped(self, evt):
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._stop_interruption()
await self.push_frame(UserStoppedSpeakingFrame())
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):

View File

@@ -25,8 +25,8 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -312,7 +312,7 @@ class PlayHTTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
"""Handle interruption by stopping metrics and clearing request ID."""
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()

View File

@@ -24,14 +24,15 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
from pipecat.transcriptions import language
from pipecat.transcriptions.language import Language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
@@ -279,7 +280,7 @@ class RimeTTSService(AudioContextWordTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
"""Handle interruption by clearing current context."""
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
@@ -374,7 +375,7 @@ class RimeTTSService(AudioContextWordTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
if isinstance(frame, TTSStoppedFrame):
await self.add_word_timestamps([("Reset", 0)])

View File

@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -455,7 +455,7 @@ class SarvamTTSService(InterruptibleTTSService):
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterruptionFrame,
OutputImageRawFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
UserStartedSpeakingFrame,
@@ -179,7 +179,7 @@ class SimliVideoService(FrameProcessor):
return
elif isinstance(frame, (EndFrame, CancelFrame)):
await self._stop()
elif isinstance(frame, (InterruptionFrame, UserStartedSpeakingFrame)):
elif isinstance(frame, (StartInterruptionFrame, UserStartedSpeakingFrame)):
if not self._previously_interrupted:
await self._simli_client.clearBuffer()
self._previously_interrupted = self._is_trinity_avatar

View File

@@ -19,6 +19,7 @@ from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EndFrame,
ErrorFrame,
@@ -748,13 +749,14 @@ class SpeechmaticsSTTService(STTService):
return
# Frames to send
upstream_frames: list[Frame] = []
downstream_frames: list[Frame] = []
# If VAD is enabled, then send a speaking frame
if self._params.enable_vad and not self._is_speaking:
logger.debug("User started speaking")
self._is_speaking = True
await self.push_interruption_task_frame_and_wait()
upstream_frames += [BotInterruptionFrame()]
downstream_frames += [UserStartedSpeakingFrame()]
# If final, then re-parse into TranscriptionFrame
@@ -792,6 +794,10 @@ class SpeechmaticsSTTService(STTService):
self._is_speaking = False
downstream_frames += [UserStoppedSpeakingFrame()]
# Send UPSTREAM frames
for frame in upstream_frames:
await self.push_frame(frame, FrameDirection.UPSTREAM)
# Send the DOWNSTREAM frames
for frame in downstream_frames:
await self.push_frame(frame, FrameDirection.DOWNSTREAM)

View File

@@ -23,12 +23,12 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterruptionFrame,
OutputAudioRawFrame,
OutputImageRawFrame,
OutputTransportReadyFrame,
SpeechOutputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
)
@@ -222,7 +222,7 @@ class TavusVideoService(AIService):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions()
await self.push_frame(frame, direction)
elif isinstance(frame, TTSAudioRawFrame):

View File

@@ -20,10 +20,10 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
TTSAudioRawFrame,
@@ -309,7 +309,7 @@ class TTSService(AIService):
and not isinstance(frame, TranscriptionFrame)
):
await self._process_text_frame(frame)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self._handle_interruption(frame, direction)
await self.push_frame(frame, direction)
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
@@ -367,14 +367,14 @@ class TTSService(AIService):
await super().push_frame(frame, direction)
if self._push_stop_frames and (
isinstance(frame, InterruptionFrame)
isinstance(frame, StartInterruptionFrame)
or isinstance(frame, TTSStartedFrame)
or isinstance(frame, TTSAudioRawFrame)
or isinstance(frame, TTSStoppedFrame)
):
await self._stop_frame_queue.put(frame)
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
self._processing_text = False
await self._text_aggregator.handle_interruption()
for filter in self._text_filters:
@@ -438,7 +438,7 @@ class TTSService(AIService):
)
if isinstance(frame, TTSStartedFrame):
has_started = True
elif isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
has_started = False
except asyncio.TimeoutError:
if has_started:
@@ -523,7 +523,7 @@ class WordTTSService(TTSService):
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
self._llm_response_started = False
self.reset_word_timestamps()
@@ -613,7 +613,7 @@ class InterruptibleTTSService(WebsocketTTSService):
# user interrupts we need to reconnect.
self._bot_speaking = False
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
if self._bot_speaking:
await self._disconnect()
@@ -685,7 +685,7 @@ class InterruptibleWordTTSService(WebsocketWordTTSService):
# user interrupts we need to reconnect.
self._bot_speaking = False
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
if self._bot_speaking:
await self._disconnect()
@@ -813,7 +813,7 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
await super().cancel(frame)
await self._stop_audio_context_task()
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self._stop_audio_context_task()
self._create_audio_context_task()

View File

@@ -128,7 +128,7 @@ async def run_test(
expected_up_frames: Optional[Sequence[type]] = None,
ignore_start: bool = True,
observers: Optional[List[BaseObserver]] = None,
pipeline_params: Optional[PipelineParams] = None,
start_metadata: Optional[Dict[str, Any]] = None,
send_end_frame: bool = True,
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
"""Run a test pipeline with the specified processor and validate frame flow.
@@ -144,7 +144,7 @@ async def run_test(
expected_up_frames: Expected frame types flowing upstream (optional).
ignore_start: Whether to ignore StartFrames in frame validation.
observers: Optional list of observers to attach to the pipeline.
pipeline_params: Optional pipeline parameters.
start_metadata: Optional metadata to include with the StartFrame.
send_end_frame: Whether to send an EndFrame at the end of the test.
Returns:
@@ -154,7 +154,7 @@ async def run_test(
AssertionError: If the received frames don't match the expected frame types.
"""
observers = observers or []
pipeline_params = pipeline_params or PipelineParams()
start_metadata = start_metadata or {}
received_up = asyncio.Queue()
received_down = asyncio.Queue()
@@ -173,7 +173,7 @@ async def run_test(
task = PipelineTask(
pipeline,
params=pipeline_params,
params=PipelineParams(start_metadata=start_metadata),
observers=observers,
cancel_on_idle_timeout=False,
)

View File

@@ -22,6 +22,7 @@ from pipecat.audio.turn.base_turn_analyzer import (
)
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -35,6 +36,7 @@ from pipecat.frames.frames import (
MetricsFrame,
SpeechControlParamsFrame,
StartFrame,
StartInterruptionFrame,
StopFrame,
SystemFrame,
UserSpeakingFrame,
@@ -287,6 +289,8 @@ class BaseInputTransport(FrameProcessor):
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
await self._handle_bot_interruption(frame)
elif isinstance(frame, BotStartedSpeakingFrame):
await self._handle_bot_started_speaking(frame)
await self.push_frame(frame, direction)
@@ -331,6 +335,13 @@ class BaseInputTransport(FrameProcessor):
# Handle interruptions
#
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
"""Handle bot interruption frames."""
logger.debug("Bot interruption")
if self.interruptions_allowed:
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
async def _handle_user_interruption(self, vad_state: VADState, emulated: bool = False):
"""Handle user interruption events based on speaking state."""
if vad_state == VADState.SPEAKING:
@@ -342,7 +353,7 @@ class BaseInputTransport(FrameProcessor):
await self.push_frame(downstream_frame)
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
# Only push InterruptionFrame if:
# Only push StartInterruptionFrame if:
# 1. No interruption config is set, OR
# 2. Interruption config is set but bot is not speaking
should_push_immediate_interruption = (
@@ -351,7 +362,11 @@ class BaseInputTransport(FrameProcessor):
# Make sure we notify about interruptions quickly out-of-band.
if should_push_immediate_interruption and self.interruptions_allowed:
await self.push_interruption_task_frame_and_wait()
await self._start_interruption()
# Push an out-of-band frame (i.e. not using the ordered push
# frame task) to stop everything, especially at the output
# transport.
await self.push_frame(StartInterruptionFrame())
elif self.interruption_strategies and self._bot_speaking:
logger.debug(
"User started speaking while bot is speaking with interruption config - "
@@ -366,6 +381,9 @@ class BaseInputTransport(FrameProcessor):
await self.push_frame(downstream_frame)
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
if self.interruptions_allowed:
await self._stop_interruption()
#
# Handle bot speaking state
#

View File

@@ -30,7 +30,6 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputTransportMessageUrgentFrame,
InterruptionFrame,
MixerControlFrame,
OutputAudioRawFrame,
OutputDTMFFrame,
@@ -40,6 +39,7 @@ from pipecat.frames.frames import (
SpeechOutputAudioRawFrame,
SpriteFrame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
@@ -287,8 +287,9 @@ class BaseOutputTransport(FrameProcessor):
await super().process_frame(frame, direction)
#
# System frames (like InterruptionFrame) are pushed immediately. Other
# frames require order so they are put in the sink queue.
# System frames (like StartInterruptionFrame) are pushed
# immediately. Other frames require order so they are put in the sink
# queue.
#
if isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
@@ -298,7 +299,7 @@ class BaseOutputTransport(FrameProcessor):
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
elif isinstance(frame, StartInterruptionFrame):
await self.push_frame(frame, direction)
await self._handle_frame(frame)
elif isinstance(frame, TransportMessageUrgentFrame) and not isinstance(
@@ -339,7 +340,7 @@ class BaseOutputTransport(FrameProcessor):
sender = self._media_senders[frame.transport_destination]
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await sender.handle_interruptions(frame)
elif isinstance(frame, OutputAudioRawFrame):
await sender.handle_audio_frame(frame)
@@ -490,7 +491,7 @@ class BaseOutputTransport(FrameProcessor):
await self._cancel_clock_task()
await self._cancel_video_task()
async def handle_interruptions(self, _: InterruptionFrame):
async def handle_interruptions(self, _: StartInterruptionFrame):
"""Handle interruption events by restarting tasks and clearing buffers.
Args:
@@ -671,7 +672,7 @@ class BaseOutputTransport(FrameProcessor):
frame = self._audio_queue.get_nowait()
if isinstance(frame, OutputAudioRawFrame):
frame.audio = await self._mixer.mix(frame.audio)
last_frame_time = time.time()
last_frame_time = time.time()
yield frame
except asyncio.QueueEmpty:
# Notify the bot stopped speaking upstream if necessary.

View File

@@ -25,7 +25,6 @@ from pydantic import BaseModel
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
from pipecat.frames.frames import (
CancelFrame,
ControlFrame,
EndFrame,
ErrorFrame,
Frame,
@@ -42,7 +41,6 @@ from pipecat.frames.frames import (
UserAudioRawFrame,
UserImageRawFrame,
UserImageRequestFrame,
DataFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
from pipecat.transcriptions.language import Language
@@ -107,17 +105,6 @@ class DailyInputTransportMessageUrgentFrame(InputTransportMessageUrgentFrame):
participant_id: Optional[str] = None
@dataclass
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
"""Frame to update remote participants in Daily calls.
Parameters:
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
"""
remote_participants: Mapping[str, Any] = None
class WebRTCVADAnalyzer(VADAnalyzer):
"""Voice Activity Detection analyzer using WebRTC.
@@ -228,7 +215,6 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Called when the active speaker of the call has changed.
on_joined: Called when bot successfully joined a room.
on_left: Called when bot left a room.
on_before_leave: Called when bot is about to leave the room.
on_error: Called when an error occurs.
on_app_message: Called when receiving an app message.
on_call_state_updated: Called when call state changes.
@@ -258,7 +244,6 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
on_left: Callable[[], Awaitable[None]]
on_before_leave: Callable[[], Awaitable[None]]
on_error: Callable[[str], Awaitable[None]]
on_app_message: Callable[[Any, str], Awaitable[None]]
on_call_state_updated: Callable[[str], Awaitable[None]]
@@ -374,7 +359,6 @@ class DailyTransportClient(EventHandler):
self._transcription_ids = []
self._transcription_status = None
self._dial_out_session_id: str = ""
self._dial_in_session_id: str = ""
self._joining = False
self._joined = False
@@ -735,9 +719,6 @@ class DailyTransportClient(EventHandler):
logger.info(f"Leaving {self._room_url}")
# Call callback before leaving.
await self._callbacks.on_before_leave()
if self._params.transcription_enabled:
await self.stop_transcription()
@@ -842,16 +823,6 @@ class DailyTransportClient(EventHandler):
Args:
settings: SIP call transfer settings.
"""
session_id = (
settings.get("sessionId") or self._dial_out_session_id or self._dial_in_session_id
)
if not session_id:
logger.error("Unable to transfer SIP call: 'sessionId' is not set")
return
# Update 'sessionId' field.
settings["sessionId"] = session_id
future = self._get_event_loop().create_future()
self._client.sip_call_transfer(settings, completion=completion_callback(future))
await future
@@ -1170,7 +1141,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in connection data.
"""
self._dial_in_session_id = data["sessionId"] if "sessionId" in data else ""
self._call_event_callback(self._callbacks.on_dialin_connected, data)
def on_dialin_ready(self, sip_endpoint: str):
@@ -1187,9 +1157,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in stop data.
"""
# Cleanup only if our session stopped.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_stopped, data)
def on_dialin_error(self, data: Any):
@@ -1198,9 +1165,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in error data.
"""
# Cleanup only if our session errored out.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_error, data)
def on_dialin_warning(self, data: Any):
@@ -1235,7 +1199,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out stop data.
"""
# Cleanup only if our session stopped.
if data.get("sessionId") == self._dial_out_session_id:
if data["sessionId"] == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_stopped, data)
@@ -1246,7 +1210,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out error data.
"""
# Cleanup only if our session errored out.
if data.get("sessionId") == self._dial_out_session_id:
if data["sessionId"] == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_error, data)
@@ -1803,31 +1767,6 @@ class DailyOutputTransport(BaseOutputTransport):
# Leave the room.
await self._client.leave()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process outgoing frames, including transport messages.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
logger.debug(f"Got a DailyUpdateRemoteParticipantsFrame: {frame}")
await self._client.update_remote_participants(frame.remote_participants)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process outgoing frames, including transport messages.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
await self._client.update_remote_participants(frame.remote_participants)
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
"""Send a transport message to participants.
@@ -1923,7 +1862,6 @@ class DailyTransport(BaseTransport):
on_active_speaker_changed=self._on_active_speaker_changed,
on_joined=self._on_joined,
on_left=self._on_left,
on_before_leave=self._on_before_leave,
on_error=self._on_error,
on_app_message=self._on_app_message,
on_call_state_updated=self._on_call_state_updated,
@@ -1987,10 +1925,6 @@ class DailyTransport(BaseTransport):
self._register_event_handler("on_recording_started")
self._register_event_handler("on_recording_stopped")
self._register_event_handler("on_recording_error")
self._register_event_handler("on_before_disconnect", sync=True)
# Deprecated
self._register_event_handler("on_joined")
self._register_event_handler("on_left")
#
# BaseTransport
@@ -2242,10 +2176,6 @@ class DailyTransport(BaseTransport):
"""Handle room left events."""
await self._call_event_handler("on_left")
async def _on_before_leave(self):
"""Handle before leave room events."""
await self._call_event_handler("on_before_disconnect")
async def _on_error(self, error):
"""Handle error events and push error frames."""
await self._call_event_handler("on_error", error)
@@ -2385,7 +2315,7 @@ class DailyTransport(BaseTransport):
"""Handle participant updated events."""
await self._call_event_handler("on_participant_updated", participant)
async def _on_transcription_message(self, message: Mapping[str, Any]) -> None:
async def _on_transcription_message(self, message: Dict[str, Any]) -> None:
"""Handle transcription message events."""
await self._call_event_handler("on_transcription_message", message)
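
The dial-out handlers in this file switch between `data.get("sessionId")` and `data["sessionId"]`. The two forms behave differently when the event payload has no `sessionId` key. A small sketch, with hypothetical helper names and plain dictionaries standing in for the Daily event data:

```python
from typing import Any, Mapping


def matches_with_get(data: Mapping[str, Any], session_id: str) -> bool:
    # A missing key yields None, so the comparison is simply False.
    return data.get("sessionId") == session_id


def matches_with_index(data: Mapping[str, Any], session_id: str) -> bool:
    # A missing key raises KeyError before any comparison happens.
    return data["sessionId"] == session_id
```

With the indexing form, a payload that lacks `sessionId` would raise inside the handler before the callback is invoked, so the `.get()` form is the more defensive choice if the key is not guaranteed.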

View File

@@ -114,7 +114,6 @@ class LiveKitCallbacks(BaseModel):
on_connected: Callable[[], Awaitable[None]]
on_disconnected: Callable[[], Awaitable[None]]
on_before_disconnect: Callable[[], Awaitable[None]]
on_participant_connected: Callable[[str], Awaitable[None]]
on_participant_disconnected: Callable[[str], Awaitable[None]]
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
@@ -283,7 +282,6 @@ class LiveKitTransportClient:
return
logger.info(f"Disconnecting from {self._room_name}")
await self._callbacks.on_before_disconnect()
await self.room.disconnect()
self._connected = False
logger.info(f"Disconnected from {self._room_name}")
@@ -920,7 +918,6 @@ class LiveKitTransport(BaseTransport):
callbacks = LiveKitCallbacks(
on_connected=self._on_connected,
on_disconnected=self._on_disconnected,
on_before_disconnect=self._on_before_disconnect,
on_participant_connected=self._on_participant_connected,
on_participant_disconnected=self._on_participant_disconnected,
on_audio_track_subscribed=self._on_audio_track_subscribed,
@@ -950,7 +947,6 @@ class LiveKitTransport(BaseTransport):
self._register_event_handler("on_first_participant_joined")
self._register_event_handler("on_participant_left")
self._register_event_handler("on_call_state_updated")
self._register_event_handler("on_before_disconnect", sync=True)
def input(self) -> LiveKitInputTransport:
"""Get the input transport for receiving media and events.
@@ -1045,10 +1041,6 @@ class LiveKitTransport(BaseTransport):
"""Handle room disconnected events."""
await self._call_event_handler("on_disconnected")
async def _on_before_disconnect(self):
"""Handle before disconnection room events."""
await self._call_event_handler("on_before_disconnect")
async def _on_participant_connected(self, participant_id: str):
"""Handle participant connected events."""
await self._call_event_handler("on_participant_connected", participant_id)

View File

@@ -95,20 +95,15 @@ class SmallWebRTCTrack:
enable/disable control and frame discarding for audio and video streams.
"""
def __init__(self, receiver):
def __init__(self, track: MediaStreamTrack):
"""Initialize the WebRTC track wrapper.
Args:
receiver: The RemoteStreamTrack receiver instance.
track: The underlying MediaStreamTrack to wrap.
index: The index of the track in the transceiver (0 for mic, 1 for cam, 2 for screen)
"""
self._receiver = receiver
# Configure the receiver not to consume the track by default, to prevent memory growth
self._receiver._enabled = False
self._track = receiver.track
self._track = track
self._enabled = True
self._last_recv_time: float = 0.0
self._idle_task: Optional[asyncio.Task] = None
self._idle_timeout: float = 2.0 # seconds before discarding old frames
def set_enabled(self, enabled: bool) -> None:
"""Enable or disable the track.
@@ -143,44 +138,13 @@ class SmallWebRTCTrack:
async def recv(self) -> Optional[Frame]:
"""Receive the next frame from the track.
Enables the internal receiving state and starts idle watcher.
Returns:
The next frame, except for video tracks, where it returns the frame only if the track is enabled, otherwise, returns None.
"""
self._receiver._enabled = True
self._last_recv_time = time.time()
# start idle watcher if not already running
if not self._idle_task or self._idle_task.done():
self._idle_task = asyncio.create_task(self._idle_watcher())
if not self._enabled and self._track.kind == "video":
return None
return await self._track.recv()
async def _idle_watcher(self):
"""Disable receiving if idle for more than _idle_timeout and monitor queue size."""
while self._receiver._enabled:
await asyncio.sleep(self._idle_timeout)
idle_duration = time.time() - self._last_recv_time
if idle_duration >= self._idle_timeout:
# discard old frames to prevent memory growth
logger.debug(
f"Disabling receiver for {self._track.kind} track after {idle_duration:.2f}s idle"
)
await self.discard_old_frames()
self._receiver._enabled = False
def stop(self):
"""Stop receiving frames from the track."""
self._receiver._enabled = False
if self._idle_task:
self._idle_task.cancel()
self._idle_task = None
if self._track:
self._track.stop()
def __getattr__(self, name):
"""Forward attribute access to the underlying track.
@@ -490,10 +454,6 @@ class SmallWebRTCConnection(BaseObject):
async def _close(self):
"""Close the peer connection and cleanup resources."""
for track in self._track_map.values():
if track:
track.stop()
self._track_map.clear()
if self._pc:
await self._pc.close()
self._message_queue.clear()
@@ -566,8 +526,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No audio transceiver is available")
return None
receiver = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver
audio_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
audio_track = SmallWebRTCTrack(track) if track else None
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
return audio_track
@@ -588,8 +548,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No video transceiver is available")
return None
receiver = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
@@ -610,8 +570,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No screen video transceiver is available")
return None
receiver = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
self._track_map[SCREEN_VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track

View File

@@ -1,200 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""SmallWebRTC request handler for managing peer connections.
This module provides a client for handling web requests and managing WebRTC connections.
"""
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional
from fastapi import HTTPException
from loguru import logger
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
@dataclass
class SmallWebRTCRequest:
"""Small WebRTC transport session arguments for the runner.
Parameters:
sdp: The SDP string (Session Description Protocol).
type: The type of the SDP, either "offer" or "answer".
pc_id: Optional identifier for the peer connection.
restart_pc: Optional flag indicating whether to restart the peer connection.
request_data: Optional custom data sent by the customer.
"""
sdp: str
type: str
pc_id: Optional[str] = None
restart_pc: Optional[bool] = None
request_data: Optional[Any] = None
class ConnectionMode(Enum):
"""Enum defining the connection handling modes."""
SINGLE = "single" # Only one active connection allowed
MULTIPLE = "multiple" # Multiple simultaneous connections allowed
class SmallWebRTCRequestHandler:
"""SmallWebRTC request handler for managing peer connections.
This class is responsible for:
- Handling incoming SmallWebRTC requests.
- Creating and managing WebRTC peer connections.
- Supporting ESP32-specific SDP munging if enabled.
- Invoking callbacks for newly initialized connections.
- Supporting both single and multiple connection modes.
"""
def __init__(
self,
ice_servers: Optional[List[IceServer]] = None,
esp32_mode: bool = False,
host: Optional[str] = None,
connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
) -> None:
"""Initialize a SmallWebRTC request handler.
Args:
ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
connections.
esp32_mode (bool): If True, enables ESP32-specific SDP munging.
host (Optional[str]): Host address used for SDP munging in ESP32 mode.
Ignored if `esp32_mode` is False.
connection_mode (ConnectionMode): Mode of operation for handling connections.
SINGLE allows only one active connection, MULTIPLE allows several.
"""
self._ice_servers = ice_servers
self._esp32_mode = esp32_mode
self._host = host
self._connection_mode = connection_mode
# Store connections by pc_id
self._pcs_map: Dict[str, SmallWebRTCConnection] = {}
def _check_single_connection_constraints(self, pc_id: Optional[str]) -> None:
"""Check if the connection request satisfies single connection mode constraints.
Args:
pc_id: The peer connection ID from the request
Raises:
HTTPException: If constraints are violated in single connection mode
"""
if self._connection_mode != ConnectionMode.SINGLE:
return
if not self._pcs_map: # No existing connections
return
# Get the existing connection (should be only one in single mode)
existing_connection = next(iter(self._pcs_map.values()))
if existing_connection.pc_id != pc_id and pc_id:
logger.warning(
f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
)
raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")
if not pc_id:
logger.warning(
"Cannot create new connection: existing connection found but no pc_id received"
)
raise HTTPException(
status_code=400,
detail="Cannot create new connection with existing connection active",
)
async def handle_web_request(
self,
request: SmallWebRTCRequest,
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
) -> None:
"""Handle a SmallWebRTC request and resolve the pending answer.
This method will:
- Reuse an existing WebRTC connection if `pc_id` exists.
- Otherwise, create a new `SmallWebRTCConnection`.
- Invoke the provided callback with the connection.
- Manage ESP32-specific munging if enabled.
- Enforce single/multiple connection mode constraints.
Args:
request (SmallWebRTCRequest): The incoming WebRTC request, containing
SDP, type, and optionally a `pc_id`.
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
asynchronous callback function that is invoked with the WebRTC connection.
Raises:
HTTPException: If connection mode constraints are violated
Exception: Any exception raised during request handling or callback execution
will be logged and propagated.
"""
try:
pc_id = request.pc_id
# Check connection mode constraints first
self._check_single_connection_constraints(pc_id)
# After constraints are satisfied, get the existing connection if any
existing_connection = self._pcs_map.get(pc_id) if pc_id else None
if existing_connection:
pipecat_connection = existing_connection
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request.sdp,
type=request.type,
restart_pc=request.restart_pc or False,
)
else:
pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
await pipecat_connection.initialize(sdp=request.sdp, type=request.type)
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
self._pcs_map.pop(webrtc_connection.pc_id, None)
# Invoke callback provided in runner arguments
try:
await webrtc_connection_callback(pipecat_connection)
logger.debug(
f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
)
except Exception as callback_error:
logger.error(
f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
)
answer = pipecat_connection.get_answer()
if self._esp32_mode and self._host and self._host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)
self._pcs_map[answer["pc_id"]] = pipecat_connection
return answer
except Exception as e:
logger.error(f"Error processing SmallWebRTC request: {e}")
logger.debug(f"SmallWebRTC request details: {request}")
raise
async def close(self):
"""Disconnect all peer connections and clear the connection map."""
coros = [pc.disconnect() for pc in self._pcs_map.values()]
await asyncio.gather(*coros)
self._pcs_map.clear()
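
Reviewer note: the single-connection rule enforced by `_check_single_connection_constraints` can be sketched in isolation as below. This is a minimal sketch, not the handler's actual code: `ConstraintError` is a hypothetical stand-in for `fastapi.HTTPException(status_code=400)`, `check_single_connection` is an illustrative name, and the sketch assumes the connection map is keyed by `pc_id`.

```python
from enum import Enum
from typing import Dict, Optional


class ConnectionMode(Enum):
    SINGLE = "single"  # Only one active connection allowed
    MULTIPLE = "multiple"  # Multiple simultaneous connections allowed


class ConstraintError(Exception):
    """Hypothetical stand-in for fastapi.HTTPException(status_code=400)."""


def check_single_connection(
    mode: ConnectionMode, active: Dict[str, object], pc_id: Optional[str]
) -> None:
    """Raise if a request would open a second connection in SINGLE mode."""
    # Nothing to enforce in MULTIPLE mode or when no connection exists yet.
    if mode != ConnectionMode.SINGLE or not active:
        return
    # Assumption: the map is keyed by pc_id and holds at most one entry here.
    existing_id = next(iter(active))
    if not pc_id:
        raise ConstraintError("existing connection active, no pc_id received")
    if pc_id != existing_id:
        raise ConstraintError("pc_id mismatch with existing connection")


if __name__ == "__main__":
    active = {"pc-1": object()}
    # Same pc_id: treated as a renegotiation, so no error is raised.
    check_single_connection(ConnectionMode.SINGLE, active, "pc-1")
    # MULTIPLE mode never rejects a request.
    check_single_connection(ConnectionMode.MULTIPLE, active, None)
```

In the handler itself the check runs before the map lookup, so a request that passes either reuses the existing connection or creates the first one.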

View File

@@ -478,11 +478,7 @@ class SmallWebRTCClient:
self._screen_video_track = None
self._audio_output_track = None
self._video_output_track = None
# Trigger `on_client_disconnected` if the client actually disconnects,
# that is, we are not the ones disconnecting.
if not self._closing:
await self._callbacks.on_client_disconnected(self._webrtc_connection)
await self._callbacks.on_client_disconnected(self._webrtc_connection)
async def _handle_app_message(self, message: Any):
"""Handle incoming application messages."""

View File

@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InterruptionFrame,
OutputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -618,7 +618,7 @@ class TavusOutputTransport(BaseOutputTransport):
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions()
async def _handle_interruptions(self):

View File

@@ -26,9 +26,9 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InterruptionFrame,
OutputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -138,6 +138,7 @@ class FastAPIWebsocketClient:
):
logger.warning("Closing already disconnected websocket!")
self._closing = True
await self.trigger_client_disconnected()
async def disconnect(self):
"""Disconnect the WebSocket client."""
@@ -151,6 +152,8 @@ class FastAPIWebsocketClient:
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception while closing the websocket: {e}")
finally:
await self.trigger_client_disconnected()
async def trigger_client_disconnected(self):
"""Trigger the client disconnected callback."""
@@ -295,10 +298,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
# Trigger `on_client_disconnected` if the client actually disconnects,
# that is, we are not the ones disconnecting.
if not self._client.is_closing:
await self._client.trigger_client_disconnected()
await self._client.trigger_client_disconnected()
async def _monitor_websocket(self):
"""Wait for `self._params.session_timeout` seconds; if the websocket is still open, trigger the timeout event."""
@@ -398,7 +398,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._write_frame(frame)
self._next_send_time = 0
@@ -446,9 +446,6 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
async def _write_frame(self, frame: Frame):
"""Serialize and send a frame through the WebSocket."""
if self._client.is_closing or not self._client.is_connected:
return
if not self._params.serializer:
return

View File

@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InterruptionFrame,
OutputAudioRawFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
@@ -334,7 +334,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
if isinstance(frame, StartInterruptionFrame):
await self._write_frame(frame)
self._next_send_time = 0

View File

@@ -14,33 +14,13 @@ and async cleanup for all Pipecat components.
import asyncio
import inspect
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Optional
from loguru import logger
from pipecat.utils.utils import obj_count, obj_id
@dataclass
class EventHandler:
"""Data class to store event handlers information.
This data class stores the event name, a list of handlers to run for this
event, and whether these handlers will be executed in a task.
Attributes:
name (str): The name of the event handler.
handlers (List[Any]): A list of functions to be called when this event is triggered.
is_sync (bool): Indicates whether the handlers are awaited in order (True) instead of each being executed in a separate task (False).
"""
name: str
handlers: List[Any]
is_sync: bool
class BaseObject(ABC):
"""Abstract base class providing common functionality for Pipecat objects.
@@ -61,7 +41,7 @@ class BaseObject(ABC):
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
# Registered event handlers.
self._event_handlers: Dict[str, EventHandler] = {}
self._event_handlers: dict = {}
# Set of tasks being executed. When a task finishes running it gets
# automatically removed from the set. When we cleanup we wait for all
@@ -123,21 +103,18 @@ class BaseObject(ABC):
Can be sync or async.
"""
if event_name in self._event_handlers:
self._event_handlers[event_name].handlers.append(handler)
self._event_handlers[event_name].append(handler)
else:
logger.warning(f"Event handler {event_name} not registered")
def _register_event_handler(self, event_name: str, sync: bool = False):
def _register_event_handler(self, event_name: str):
"""Register an event handler type.
Args:
event_name: The name of the event type to register.
sync: Whether handlers for this event are awaited in order instead of being executed in a separate task.
"""
if event_name not in self._event_handlers:
self._event_handlers[event_name] = EventHandler(
name=event_name, handlers=[], is_sync=sync
)
self._event_handlers[event_name] = []
else:
logger.warning(f"Event handler {event_name} already registered")
@@ -149,43 +126,34 @@ class BaseObject(ABC):
*args: Positional arguments to pass to event handlers.
**kwargs: Keyword arguments to pass to event handlers.
"""
if event_name not in self._event_handlers:
# If we haven't registered an event handler, we don't need to do
# anything.
if not self._event_handlers.get(event_name):
return
event_handler = self._event_handlers[event_name]
# Create the task.
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
for handler in event_handler.handlers:
if event_handler.is_sync:
# Just run the handler.
await self._run_handler(event_handler.name, handler, *args, **kwargs)
else:
# Create the task. Note that this is a task per each function
# handler. Users can register to an event handler multiple
# times.
task = asyncio.create_task(
self._run_handler(event_handler.name, handler, *args, **kwargs)
)
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
async def _run_handler(self, event_name: str, handler, *args, **kwargs):
async def _run_task(self, event_name: str, *args, **kwargs):
"""Execute all handlers for an event.
Args:
event_name: The event name for this handler.
handler: The handler function to run.
event_name: The name of the event being handled.
*args: Positional arguments to pass to handlers.
**kwargs: Keyword arguments to pass to handlers.
"""
try:
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
for handler in self._event_handlers[event_name]:
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
except Exception as e:
logger.exception(f"Exception in event handler {event_name}: {e}")
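
Reviewer note: the two dispatch styles that differ across this diff (awaiting handlers in order versus running each one in its own task) can be illustrated with a small standalone emitter. This is a hedged sketch: `MiniEmitter`, `register`, `add_handler`, `call`, and `demo` are hypothetical names, and the class is not Pipecat's `BaseObject`. Unlike `BaseObject`, handlers here do not receive the emitter as their first argument.

```python
import asyncio
import inspect
from typing import Any, Callable, Dict, List, Set


class MiniEmitter:
    """Hypothetical, trimmed-down emitter; not the Pipecat BaseObject."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Callable[..., Any]]] = {}
        self._sync: Dict[str, bool] = {}
        self._tasks: Set[asyncio.Task] = set()

    def register(self, name: str, sync: bool = False) -> None:
        self._handlers.setdefault(name, [])
        self._sync[name] = sync

    def add_handler(self, name: str, handler: Callable[..., Any]) -> None:
        self._handlers[name].append(handler)

    async def _run(self, handler: Callable[..., Any], *args: Any) -> None:
        # Handlers may be plain functions or coroutine functions.
        if inspect.iscoroutinefunction(handler):
            await handler(*args)
        else:
            handler(*args)

    async def call(self, name: str, *args: Any) -> None:
        for handler in self._handlers.get(name, []):
            if self._sync[name]:
                # Sync event: the caller waits until the handler has finished.
                await self._run(handler, *args)
            else:
                # Default: one task per handler; the caller does not wait.
                task = asyncio.create_task(self._run(handler, *args))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)


async def demo() -> List[str]:
    log: List[str] = []
    emitter = MiniEmitter()
    emitter.register("on_before_disconnect", sync=True)
    emitter.register("on_message")

    async def slow(tag: str) -> None:
        await asyncio.sleep(0.01)
        log.append(tag)

    emitter.add_handler("on_before_disconnect", slow)
    emitter.add_handler("on_message", slow)

    await emitter.call("on_message", "task")  # returns before the handler runs
    log.append("after-task-call")
    await emitter.call("on_before_disconnect", "sync")  # handler completes first
    log.append("after-sync-call")
    await asyncio.sleep(0.05)  # let the background task finish
    return log


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The ordering guarantee is what matters for something like a pre-disconnect hook: with `sync=True` the "sync" entry is always logged before "after-sync-call", whereas the task-based handler only completes at some later point.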

View File

@@ -8,31 +8,25 @@ import json
import unittest
from typing import Any
from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallResultProperties,
InterimTranscriptionFrame,
InterruptionFrame,
InterruptionTaskFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OpenAILLMContextAssistantTimestampFrame,
SpeechControlParamsFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -42,7 +36,6 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.anthropic.llm import (
AnthropicAssistantContextAggregator,
AnthropicLLMContext,
@@ -488,103 +481,6 @@ class BaseTestUserContextAggregator:
)
self.check_message_content(context, 0, "How are you?")
async def test_min_words_interruption_strategy_one_word(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
class ContextProcessor(FrameProcessor):
def __init__(self):
super().__init__()
self.context_received = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
self.context_received = True
await self.push_frame(frame, direction)
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
context_processor = ContextProcessor()
pipeline = Pipeline([aggregator, context_processor])
frames_to_send = [
BotStartedSpeakingFrame(),
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Can", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
BotStartedSpeakingFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
pipeline_params=PipelineParams(
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
),
)
assert not context_processor.context_received
async def test_min_words_interruption_strategy_two_words(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
class ContextProcessor(FrameProcessor):
def __init__(self):
super().__init__()
self.context_received = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
self.context_received = True
elif isinstance(frame, InterruptionFrame):
self.context_received = False
await self.push_frame(frame, direction)
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
context_processor = ContextProcessor()
pipeline = Pipeline([aggregator, context_processor])
frames_to_send = [
BotStartedSpeakingFrame(),
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Can you", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_up_frames = [InterruptionTaskFrame]
expected_down_frames = [
BotStartedSpeakingFrame,
UserStartedSpeakingFrame,
InterruptionFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_up_frames=expected_up_frames,
expected_down_frames=expected_down_frames,
pipeline_params=PipelineParams(
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
),
)
self.check_message_content(context, 0, "Can you")
# If the context is not received or it has been cleared by the
# interruption then we have an issue.
assert context_processor.context_received
class BaseTestAssistantContextAggreagator:
CONTEXT_CLASS = None # To be set in subclasses
@@ -722,7 +618,7 @@ class BaseTestAssistantContextAggreagator:
TextFrame(text="Pipecat."),
LLMFullResponseEndFrame(),
SleepFrame(AGGREGATION_SLEEP),
InterruptionFrame(),
StartInterruptionFrame(),
LLMFullResponseStartFrame(),
TextFrame(text="How are "),
TextFrame(text="you?"),
@@ -730,7 +626,7 @@ class BaseTestAssistantContextAggreagator:
]
expected_down_frames = [
*self.EXPECTED_CONTEXT_FRAMES,
InterruptionFrame,
StartInterruptionFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(

View File

@@ -10,7 +10,6 @@ from pipecat.audio.dtmf.types import KeypadEntry
from pipecat.frames.frames import (
EndFrame,
InputDTMFFrame,
InterruptionFrame,
TranscriptionFrame,
)
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
@@ -29,7 +28,6 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
InputDTMFFrame,
@@ -61,11 +59,9 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # First aggregation "12"
InputDTMFFrame,
InterruptionFrame,
TranscriptionFrame, # Second aggregation "3"
]
@@ -97,12 +93,10 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
TranscriptionFrame, # "12#"
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # "45"
]
@@ -131,7 +125,6 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # Should flush before EndFrame
EndFrame,
@@ -159,7 +152,6 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame,
]
@@ -186,7 +178,6 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
TranscriptionFrame,
@@ -223,11 +214,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
# All the InputDTMFFrames plus one TranscriptionFrame
expected_down_frames = (
[InputDTMFFrame, InterruptionFrame]
+ [InputDTMFFrame] * (len(frames_to_send) - 1)
+ [TranscriptionFrame]
)
expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame]
received_down_frames, _ = await run_test(
aggregator,

View File

@@ -1,67 +0,0 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from pipecat.frames.frames import (
EndFrame,
Frame,
InterruptionFrame,
TextFrame,
TransportMessageUrgentFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
async def test_interruption_and_wait(self):
class DelayFrameProcessor(FrameProcessor):
"""This processor just gives the event loop time to switch
between tasks. Otherwise things happen too fast."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await asyncio.sleep(0.1)
await self.push_frame(frame, direction)
class InterruptFrameProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_interruption_task_frame_and_wait()
await self.push_frame(TransportMessageUrgentFrame(message=frame.text))
else:
await self.push_frame(frame, direction)
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
frames_to_send = [
# Just a random interruption to make sure we don't clear anything
# before the actual `InterruptionTaskFrame` interruption.
InterruptionFrame(),
# This will generate an `InterruptionTaskFrame` and will wait for an
# `InterruptionFrame`.
TextFrame(text="Hello from Pipecat!"),
# Just give time for everything to complete.
SleepFrame(sleep=0.5),
EndFrame(),
]
expected_down_frames = [
InterruptionFrame,
InterruptionFrame,
TransportMessageUrgentFrame,
EndFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
send_end_frame=False,
)

View File

@@ -1,998 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""
Unit tests for LLM adapters' get_llm_invocation_params() method.
These tests focus specifically on the "messages" field generation for different adapters, ensuring:
For OpenAI adapter:
1. LLMStandardMessage objects are passed through unchanged
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
3. Complex message structures (like multi-part content) are preserved
4. System instructions are preserved throughout messages at any position
For Gemini adapter:
1. LLMStandardMessage objects are converted to Gemini Content format
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
4. System messages are extracted as system_instruction (without duplication)
5. Single system instruction is converted to user message when no other messages exist
6. Multiple system instructions: first extracted, later ones converted to user messages
For Anthropic adapter:
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
For AWS Bedrock adapter:
1. LLMStandardMessage objects are converted to AWS Bedrock format
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
"""
import unittest
from google.genai.types import Content, Part
from openai.types.chat import ChatCompletionMessage
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMSpecificMessage,
LLMStandardMessage,
)
class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = OpenAILLMAdapter()
def test_standard_messages_passed_through_unchanged(self):
"""Test that LLMStandardMessage objects are passed through unchanged to OpenAI params."""
# Create standard messages (OpenAI format)
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
]
# Create context with these messages
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify messages are passed through unchanged
self.assertEqual(params["messages"], standard_messages)
self.assertEqual(len(params["messages"]), 3)
# Verify content matches exactly
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
def test_llm_specific_message_filtering(self):
"""Test that OpenAI-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Gemini specific message"}
),
{"role": "user", "content": "Standard user message"},
self.adapter.create_llm_specific_message(
{"role": "assistant", "content": "OpenAI specific response"}
),
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only include standard messages and OpenAI-specific ones
# (3 total: system, standard user, openai assistant)
self.assertEqual(len(params["messages"]), 3)
# Verify the correct messages are included
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][1]["content"], "Standard user message")
self.assertEqual(
params["messages"][2], {"role": "assistant", "content": "OpenAI specific response"}
)
def test_complex_message_content_preserved(self):
"""Test that complex message content (like multi-part messages) is preserved."""
# Create a message with complex content structure (text + image)
complex_image_message = {
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..."},
},
],
}
# Create a message with multiple text blocks
multi_text_message = {
"role": "assistant",
"content": [
{"type": "text", "text": "Let me analyze this step by step:"},
{"type": "text", "text": "1. First, I'll examine the visual elements"},
{"type": "text", "text": "2. Then I'll provide my conclusions"},
],
}
messages = [
{"role": "system", "content": "You are a helpful assistant that can analyze images."},
complex_image_message,
multi_text_message,
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify complex content is preserved
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][1], complex_image_message)
self.assertEqual(params["messages"][2], multi_text_message)
# Verify the image message structure is maintained
image_content = params["messages"][1]["content"]
self.assertIsInstance(image_content, list)
self.assertEqual(len(image_content), 2)
self.assertEqual(image_content[0]["type"], "text")
self.assertEqual(image_content[1]["type"], "image_url")
# Verify the multi-text message structure is maintained
text_content = params["messages"][2]["content"]
self.assertIsInstance(text_content, list)
self.assertEqual(len(text_content), 3)
for i, text_block in enumerate(text_content):
self.assertEqual(text_block["type"], "text")
self.assertEqual(text_content[0]["text"], "Let me analyze this step by step:")
self.assertEqual(text_content[1]["text"], "1. First, I'll examine the visual elements")
self.assertEqual(text_content[2]["text"], "2. Then I'll provide my conclusions")
def test_system_instructions_preserved_throughout_messages(self):
"""Test that OpenAI adapter preserves system instructions sprinkled throughout messages."""
# Create messages with system instructions at different positions
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."},
{"role": "user", "content": "Tell me about Python."},
{"role": "system", "content": "Use simple language."},
{"role": "assistant", "content": "Python is a programming language."},
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# OpenAI should preserve all messages unchanged, including multiple system messages
self.assertEqual(len(params["messages"]), 7)
# Verify system messages are preserved at their original positions
self.assertEqual(params["messages"][0]["role"], "system")
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][3]["role"], "system")
self.assertEqual(params["messages"][3]["content"], "Remember to be concise.")
self.assertEqual(params["messages"][5]["role"], "system")
self.assertEqual(params["messages"][5]["content"], "Use simple language.")
# Verify other messages remain unchanged
self.assertEqual(params["messages"][1]["role"], "user")
self.assertEqual(params["messages"][2]["role"], "assistant")
self.assertEqual(params["messages"][4]["role"], "user")
self.assertEqual(params["messages"][6]["role"], "assistant")
class TestGeminiGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = GeminiLLMAdapter()
def test_standard_messages_converted_to_gemini_format(self):
"""Test that LLMStandardMessage objects are converted to Gemini Content format."""
# Create standard messages (OpenAI format)
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
]
# Create context with these messages
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction is extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Verify messages are converted to Gemini format (2 messages: user + model)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user)
user_msg = params["messages"][0]
self.assertIsInstance(user_msg, Content)
self.assertEqual(user_msg.role, "user")
self.assertEqual(len(user_msg.parts), 1)
self.assertEqual(user_msg.parts[0].text, "Hello, how are you?")
# Check second message (assistant -> model)
model_msg = params["messages"][1]
self.assertIsInstance(model_msg, Content)
self.assertEqual(model_msg.role, "model")
self.assertEqual(len(model_msg.parts), 1)
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
def test_llm_specific_message_filtering(self):
"""Test that Gemini-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific message"}
),
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
{"role": "user", "content": "Standard user message"},
self.adapter.create_llm_specific_message(
Content(role="model", parts=[Part(text="Gemini specific response")]),
),
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only include standard messages and Gemini-specific ones
# (2 total: converted standard user + gemini model)
self.assertEqual(len(params["messages"]), 2)
# Verify system instruction
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Verify the correct messages are included
self.assertEqual(params["messages"][0].role, "user")
self.assertEqual(params["messages"][0].parts[0].text, "Standard user message")
self.assertEqual(params["messages"][1].role, "model")
self.assertEqual(params["messages"][1].parts[0].text, "Gemini specific response")
def test_complex_message_content_preserved(self):
"""Test that complex message content (like multi-part messages) is preserved and converted.
This test covers image, audio, and multi-text content conversion to Gemini format.
"""
# Create a message with complex content structure (text + image)
# Using minimal valid base64 image data (the payload is a 1x1 PNG; the data URL labels it image/jpeg)
complex_image_message = {
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
},
},
],
}
# Create a message with multiple text blocks
multi_text_message = {
"role": "assistant",
"content": [
{"type": "text", "text": "Let me analyze this step by step:"},
{"type": "text", "text": "1. First, I'll examine the visual elements"},
{"type": "text", "text": "2. Then I'll provide my conclusions"},
],
}
# Create a message with audio input (text + audio)
# Using minimal valid base64 audio data (a 44-byte WAV header with an empty data chunk)
audio_message = {
"role": "user",
"content": [
{"type": "text", "text": "Can you transcribe this audio?"},
{
"type": "input_audio",
"input_audio": {
"data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=",
"format": "wav",
},
},
],
}
messages = [
{
"role": "system",
"content": "You are a helpful assistant that can analyze images and audio.",
},
complex_image_message,
multi_text_message,
audio_message,
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction
self.assertEqual(
params["system_instruction"],
"You are a helpful assistant that can analyze images and audio.",
)
# Verify complex content is converted to Gemini format
# Note: Gemini adapter may add system instruction back as user message in some cases
self.assertGreaterEqual(len(params["messages"]), 3)
# Find the different message types
user_with_image = None
model_with_text = None
user_with_audio = None
for msg in params["messages"]:
if msg.role == "user" and len(msg.parts) == 2:
# Check if it's image or audio based on the text content
if hasattr(msg.parts[0], "text") and "image" in msg.parts[0].text:
user_with_image = msg
elif hasattr(msg.parts[0], "text") and "audio" in msg.parts[0].text:
user_with_audio = msg
elif msg.role == "model" and len(msg.parts) == 3:
model_with_text = msg
# Verify the image message structure is converted properly
self.assertIsNotNone(user_with_image, "Should have user message with image")
self.assertEqual(len(user_with_image.parts), 2)
# First part should be text
self.assertEqual(user_with_image.parts[0].text, "What's in this image?")
# Second part should be image data (converted to Blob)
self.assertIsNotNone(user_with_image.parts[1].inline_data)
self.assertEqual(user_with_image.parts[1].inline_data.mime_type, "image/jpeg")
# Verify the audio message structure is converted properly
self.assertIsNotNone(user_with_audio, "Should have user message with audio")
self.assertEqual(len(user_with_audio.parts), 2)
# First part should be text
self.assertEqual(user_with_audio.parts[0].text, "Can you transcribe this audio?")
# Second part should be audio data (converted to Blob)
self.assertIsNotNone(user_with_audio.parts[1].inline_data)
self.assertEqual(user_with_audio.parts[1].inline_data.mime_type, "audio/wav")
# Verify the multi-text message structure is converted properly
self.assertIsNotNone(model_with_text, "Should have model message with multi-text")
self.assertEqual(len(model_with_text.parts), 3)
# All parts should be text
expected_texts = [
"Let me analyze this step by step:",
"1. First, I'll examine the visual elements",
"2. Then I'll provide my conclusions",
]
for i, expected_text in enumerate(expected_texts):
self.assertEqual(model_with_text.parts[i].text, expected_text)
def test_single_system_instruction_converted_to_user(self):
"""Test that when there's only a system instruction, it gets converted to user message."""
# Create context with only a system message
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
context = LLMContext(messages=messages)
params = self.adapter.get_llm_invocation_params(context)
# System instruction should be extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# But since there are no other messages, it should also be added back as a user message
self.assertEqual(len(params["messages"]), 1)
self.assertEqual(params["messages"][0].role, "user")
self.assertEqual(params["messages"][0].parts[0].text, "You are a helpful assistant.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
# Create messages with multiple system instructions
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."},
{"role": "user", "content": "Tell me about Python."},
{"role": "system", "content": "Use simple language."},
{"role": "assistant", "content": "Python is a programming language."},
]
context = LLMContext(messages=messages)
params = self.adapter.get_llm_invocation_params(context)
# First system instruction should be extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Should have 6 messages (original 7 minus 1 system instruction that was extracted)
self.assertEqual(len(params["messages"]), 6)
# Find the converted system messages (should be user role now)
converted_system_messages = []
for msg in params["messages"]:
if msg.role == "user" and (
msg.parts[0].text == "Remember to be concise."
or msg.parts[0].text == "Use simple language."
):
converted_system_messages.append(msg.parts[0].text)
# Should have 2 converted system messages
self.assertEqual(len(converted_system_messages), 2)
self.assertIn("Remember to be concise.", converted_system_messages)
self.assertIn("Use simple language.", converted_system_messages)
# Verify that regular user and assistant messages are preserved
user_messages = [msg for msg in params["messages"] if msg.role == "user"]
model_messages = [msg for msg in params["messages"] if msg.role == "model"]
# Should have 4 user messages: 2 original + 2 converted from system
self.assertEqual(len(user_messages), 4)
# Should have 2 model messages (converted from assistant)
self.assertEqual(len(model_messages), 2)
class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = AnthropicLLMAdapter()
def test_standard_messages_converted_to_anthropic_format(self):
"""Test that LLMStandardMessage objects are converted to Anthropic MessageParam format."""
# Create standard messages
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
]
# Create context
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Verify system instruction is extracted
self.assertEqual(params["system"], "You are a helpful assistant.")
# Verify messages are in the params (2 messages after system extraction)
self.assertIn("messages", params)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertEqual(user_msg["content"], "Hello, how are you?")
# Check second message (assistant)
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
def test_llm_specific_message_filtering(self):
"""Test that Anthropic-specific messages are included and others are filtered out."""
# Create anthropic-specific message content
anthropic_message_content = {
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "fake_data"},
},
],
}
messages = [
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(anthropic_message_content),
{"role": "assistant", "content": "Response"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + anthropic-specific merged)
self.assertEqual(len(params["messages"]), 2)
# First message: merged user message (standard + anthropic-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + anthropic text + anthropic image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["type"], "text")
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertEqual(user_msg["content"][2]["type"], "image")
# Second message: standard response
self.assertEqual(params["messages"][1]["content"], "Response")
def test_consecutive_same_role_messages_merged(self):
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
messages = [
{"role": "user", "content": "First user message"},
{"role": "user", "content": "Second user message"},
{"role": "user", "content": "Third user message"},
{"role": "assistant", "content": "First assistant message"},
{"role": "assistant", "content": "Second assistant message"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Should have 2 messages after merging (1 user, 1 assistant)
self.assertEqual(len(params["messages"]), 2)
# Check merged user message
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["type"], "text")
self.assertEqual(user_msg["content"][0]["text"], "First user message")
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
self.assertEqual(user_msg["content"][2]["type"], "text")
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
# Check merged assistant message
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 2)
self.assertEqual(assistant_msg["content"][0]["type"], "text")
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
self.assertEqual(assistant_msg["content"][1]["type"], "text")
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
def test_empty_text_converted_to_empty_placeholder(self):
"""Test that empty text content is converted to "(empty)" string."""
messages = [
{"role": "user", "content": ""}, # Empty string
{
"role": "assistant",
"content": [
{"type": "text", "text": ""}, # Empty text in list content
{"type": "text", "text": "Valid text"},
],
},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Check that empty string content was converted
user_msg = params["messages"][0]
self.assertEqual(user_msg["content"], "(empty)")
# Check that empty text in list content was converted
assistant_msg = params["messages"][1]
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
def test_complex_message_content_preserved(self):
"""Test that complex message structures (text + image) are properly converted to Anthropic format."""
# Create a complex message with both text and image content
complex_message = {
"role": "user",
"content": [
{"type": "text", "text": "What do you see in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,fake_image_data"},
},
{"type": "text", "text": "Please describe it in detail."},
],
}
messages = [
complex_message,
{"role": "assistant", "content": "I can see the image clearly."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Verify complex message structure is preserved and converted
self.assertEqual(len(params["messages"]), 2)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
# Note: Anthropic adapter reorders single images to come before text, as per Anthropic docs
# Check image part (should be moved to first position and converted from image_url to image)
self.assertEqual(user_msg["content"][0]["type"], "image")
self.assertIn("source", user_msg["content"][0])
self.assertEqual(user_msg["content"][0]["source"]["type"], "base64")
self.assertEqual(user_msg["content"][0]["source"]["media_type"], "image/jpeg")
self.assertEqual(user_msg["content"][0]["source"]["data"], "fake_image_data")
# Check first text part (moved to second position)
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
# Check second text part (moved to third position)
self.assertEqual(user_msg["content"][2]["type"], "text")
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."}, # Later system message
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# System instruction should be extracted from first message
self.assertEqual(params["system"], "You are a helpful assistant.")
# Should have 3 messages remaining (system message was removed, later system converted to user)
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"], "Hello")
self.assertEqual(params["messages"][1]["role"], "assistant")
self.assertEqual(params["messages"][1]["content"], "Hi there!")
# Later system message should be converted to user role
self.assertEqual(params["messages"][2]["role"], "user")
self.assertEqual(params["messages"][2]["content"], "Remember to be concise.")
def test_single_system_message_converted_to_user(self):
"""Test that a single system message is converted to user role when no other messages exist."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# System should be NOT_GIVEN since we only have one message
from anthropic import NOT_GIVEN
self.assertEqual(params["system"], NOT_GIVEN)
# Single system message should be converted to user role
self.assertEqual(len(params["messages"]), 1)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = AWSBedrockLLMAdapter()
def test_standard_messages_converted_to_aws_bedrock_format(self):
"""Test that LLMStandardMessage objects are converted to AWS Bedrock format."""
# Create standard messages
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
]
# Create context
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction is extracted (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# Verify messages are in the params (2 messages after system extraction)
self.assertIn("messages", params)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user) - should be converted to AWS Bedrock format
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 1)
self.assertEqual(user_msg["content"][0]["text"], "Hello, how are you?")
# Check second message (assistant) - should be converted to AWS Bedrock format
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 1)
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
def test_llm_specific_message_filtering(self):
"""Test that AWS-specific messages are included and others are filtered out."""
# Create aws-specific message content (which is what AWS Bedrock uses)
aws_message_content = {
"role": "user",
"content": [
{"text": "Hello"},
{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}},
],
}
messages = [
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(message=aws_message_content),
{"role": "assistant", "content": "Response"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + aws-specific merged)
self.assertEqual(len(params["messages"]), 2)
# First message: merged user message (standard + aws-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + aws text + aws image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertIn("image", user_msg["content"][2])
# Second message: standard response
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
def test_consecutive_same_role_messages_merged(self):
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
messages = [
{"role": "user", "content": "First user message"},
{"role": "user", "content": "Second user message"},
{"role": "user", "content": "Third user message"},
{"role": "assistant", "content": "First assistant message"},
{"role": "assistant", "content": "Second assistant message"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should have 2 messages after merging (1 user, 1 assistant)
self.assertEqual(len(params["messages"]), 2)
# Check merged user message
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["text"], "First user message")
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
# Check merged assistant message
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 2)
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
def test_empty_text_converted_to_empty_placeholder(self):
"""Test that empty text content is converted to "(empty)" string."""
messages = [
{"role": "user", "content": ""}, # Empty string
{
"role": "assistant",
"content": [
{"type": "text", "text": ""}, # Empty text in list content
{"type": "text", "text": "Valid text"},
],
},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Check that empty string content was converted
user_msg = params["messages"][0]
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(user_msg["content"][0]["text"], "(empty)")
# Check that empty text in list content was converted
assistant_msg = params["messages"][1]
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
def test_complex_message_content_preserved(self):
"""Test that complex message structures (text + image) are properly converted to AWS Bedrock format."""
# Create a complex message with both text and image content
# Use a valid base64 string for the image
complex_message = {
"role": "user",
"content": [
{"type": "text", "text": "What do you see in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
},
},
{"type": "text", "text": "Please describe it in detail."},
],
}
messages = [
complex_message,
{"role": "assistant", "content": "I can see the image clearly."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify complex message structure is preserved and converted
self.assertEqual(len(params["messages"]), 2)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
# Note: AWS Bedrock adapter reorders single images to come before text, like Anthropic
# Check image part (should be moved to first position and converted from image_url to image)
self.assertIn("image", user_msg["content"][0])
self.assertEqual(user_msg["content"][0]["image"]["format"], "jpeg")
self.assertIn("source", user_msg["content"][0]["image"])
self.assertIn("bytes", user_msg["content"][0]["image"]["source"])
# Check first text part (moved to second position)
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
# Check second text part (moved to third position)
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."}, # Later system message
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# System instruction should be extracted from first message (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# Should have 3 messages remaining (system message was removed, later system converted to user)
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
self.assertEqual(params["messages"][1]["role"], "assistant")
self.assertEqual(params["messages"][1]["content"][0]["text"], "Hi there!")
# Later system message should be converted to user role
self.assertEqual(params["messages"][2]["role"], "user")
self.assertEqual(params["messages"][2]["content"][0]["text"], "Remember to be concise.")
def test_single_system_message_handling(self):
"""Test that a single system message is extracted as system parameter and no messages remain."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# System should be extracted (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# No messages should remain after system extraction
self.assertEqual(len(params["messages"]), 0)
if __name__ == "__main__":
unittest.main()
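
The Anthropic-style tests above check three normalization rules: a leading system message is extracted (or turned into a user message when it is the only message), later system messages become user messages, and consecutive same-role messages are merged into a list of text blocks. The sketch below illustrates those rules on plain dicts. It is not the adapter's actual implementation: `normalize_messages` and `_as_blocks` are hypothetical names, and it leaves out image handling, the "(empty)" placeholder, and the Bedrock variant, which always uses list content and extracts a lone system message.

```python
from typing import Any, Optional


def _as_blocks(content: Any) -> list:
    """Wrap string content in a single text block; copy list content as-is."""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    return [dict(block) for block in content]


def normalize_messages(messages: list) -> tuple:
    """Return (system_instruction, messages) following the rules the tests check.

    Hypothetical helper for illustration only.
    """
    msgs = [dict(m) for m in messages]  # shallow copies, input stays untouched
    system: Optional[str] = None

    # Rule 1: a leading system message becomes the system instruction,
    # unless it is the only message, in which case it becomes a user turn.
    if msgs and msgs[0]["role"] == "system":
        if len(msgs) == 1:
            msgs[0]["role"] = "user"
        else:
            system = msgs.pop(0)["content"]

    merged: list = []
    for msg in msgs:
        # Rule 2: later system messages are converted to user messages.
        if msg["role"] == "system":
            msg["role"] = "user"
        # Rule 3: consecutive same-role messages merge into one block list.
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] = _as_blocks(merged[-1]["content"]) + _as_blocks(
                msg["content"]
            )
        else:
            merged.append(msg)
    return system, merged


system, out = normalize_messages(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "system", "content": "Remember to be concise."},
    ]
)
```

With this input, `system` holds the first system message and `out` holds three messages, the last being the later system message with the `user` role, which mirrors `test_multiple_system_instructions_handling`.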

View File

@@ -7,10 +7,10 @@
import unittest
from pipecat.frames.frames import (
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartInterruptionFrame,
)
from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
from pipecat.tests.utils import SleepFrame, run_test
@@ -113,7 +113,7 @@ class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
LLMFullResponseStartFrame(),
LLMTextFrame("Hello "),
SleepFrame(),
InterruptionFrame(),
StartInterruptionFrame(),
LLMFullResponseStartFrame(),
LLMTextFrame("Hello "),
LLMTextFrame("there!"),
@@ -122,7 +122,7 @@ class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
expected_down_frames = [
LLMFullResponseStartFrame,
LLMTextFrame,
InterruptionFrame,
StartInterruptionFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
LLMTextFrame,

View File

@@ -65,7 +65,7 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
ignore_start=False,
pipeline_params=PipelineParams(start_metadata={"foo": "bar"}),
start_metadata={"foo": "bar"},
)
assert "foo" in received_down[-1].metadata
@@ -196,10 +196,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
nonlocal start_received
start_received = True
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame: Frame):
@task.event_handler("on_pipeline_ended")
async def on_pipeline_ended(task, frame: EndFrame):
nonlocal end_received
end_received = isinstance(frame, EndFrame)
end_received = True
await task.queue_frame(EndFrame())
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
@@ -214,10 +214,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame: Frame):
@task.event_handler("on_pipeline_stopped")
async def on_pipeline_ended(task, frame: StopFrame):
nonlocal stop_received
stop_received = isinstance(frame, StopFrame)
stop_received = True
await task.queue_frame(StopFrame())
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
@@ -441,10 +441,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
async def on_pipeline_started(task: PipelineTask, frame: StartFrame):
await task.cancel()
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
@task.event_handler("on_pipeline_cancelled")
async def on_pipeline_cancelled(task: PipelineTask, frame: CancelFrame):
nonlocal cancelled
cancelled = isinstance(frame, CancelFrame)
cancelled = True
try:
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
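
The hunks above show two handler styles: a single `on_pipeline_finished` handler that checks the frame type, and separate `on_pipeline_ended`, `on_pipeline_stopped`, and `on_pipeline_cancelled` handlers. The sketch below shows the single-handler style in isolation. It uses only the standard library; `MiniTask`, `finish`, and the `Fake*Frame` classes are hypothetical stand-ins and not the pipecat API.

```python
import asyncio


class FakeEndFrame:
    """Stand-in for a frame that ends the pipeline normally."""


class FakeStopFrame:
    """Stand-in for a frame that stops the pipeline."""


class FakeCancelFrame:
    """Stand-in for a frame that cancels the pipeline."""


class MiniTask:
    """Hypothetical task that exposes named events via a decorator."""

    def __init__(self) -> None:
        self._handlers: dict = {}

    def event_handler(self, name: str):
        # Register the decorated coroutine function under the event name.
        def decorator(func):
            self._handlers.setdefault(name, []).append(func)
            return func

        return decorator

    async def finish(self, frame: object) -> None:
        # Fire the one "finished" event, passing the frame that caused it.
        for handler in self._handlers.get("on_pipeline_finished", []):
            await handler(self, frame)


async def main() -> list:
    task = MiniTask()
    seen: list = []

    @task.event_handler("on_pipeline_finished")
    async def on_pipeline_finished(task, frame):
        # One handler inspects the frame type instead of three separate events.
        seen.append(type(frame).__name__)

    for frame in (FakeEndFrame(), FakeStopFrame(), FakeCancelFrame()):
        await task.finish(frame)
    return seen


result = asyncio.run(main())
```

The trade-off is visible in the tests: with one event, each handler needs an `isinstance` check to tell the cases apart, while per-event handlers can set their flag unconditionally.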

View File

@@ -1,261 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from anthropic import NOT_GIVEN
from openai import NotGiven
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
@pytest.mark.asyncio
async def test_openai_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response."""
# Create service with mocked client
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello, world!"},
]
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello! How can I help you today?"
service._client.chat.completions.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.chat.completions.create.assert_called_once_with(
model="gpt-4",
messages=test_messages,
stream=False,
)
@pytest.mark.asyncio
async def test_openai_run_inference_client_exception():
"""Test that exceptions from the client are propagated."""
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.chat.completions.create.side_effect = Exception("API Error")
with pytest.raises(Exception, match="API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_anthropic_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Anthropic."""
# Create service with mocked client
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=test_messages, system=test_system, tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Hello! How can I help you today?"
service._client.messages.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(
mock_context, enable_prompt_caching=False
)
service._client.messages.create.assert_called_once_with(
model="claude-3-sonnet-20240229",
messages=test_messages,
system=test_system,
max_tokens=8192,
stream=False,
)
@pytest.mark.asyncio
async def test_anthropic_run_inference_client_exception():
"""Test that exceptions from the Anthropic client are propagated."""
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=[], system="Test system", tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.messages.create.side_effect = Exception("Anthropic API Error")
with pytest.raises(Exception, match="Anthropic API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_google_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Google."""
# Create service with mocked client
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=test_messages, system_instruction=test_system, tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content = MagicMock()
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_google_run_inference_client_exception():
"""Test that exceptions from the Google client are propagated."""
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=[], system_instruction="Test system", tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(
side_effect=Exception("Google API Error")
)
with pytest.raises(Exception, match="Google API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
# Create service and patch the session client method
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
test_system = [{"text": "You are a helpful assistant"}]
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=test_messages, system=test_system, tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock the client and response
mock_client = AsyncMock()
mock_response = {
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
}
mock_client.converse.return_value = mock_response
# Patch the _aws_session.client method to be an async context manager
async def mock_client_cm(*args, **kwargs):
return mock_client
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
mock_client.converse.assert_called_once()
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_client_exception():
"""Test that exceptions from the AWS Bedrock client are propagated."""
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock AWS client to raise exception
mock_client = AsyncMock()
mock_client.converse.side_effect = Exception("Bedrock API Error")
# Patch the _aws_session.client method to be an async context manager
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
with pytest.raises(Exception, match="Bedrock API Error"):
await service.run_inference(mock_context)
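
The Bedrock tests in this file rely on one mocking pattern: a session whose `client(...)` call returns an async context manager that yields a mocked client. Below is a minimal, self-contained sketch of that pattern using only `unittest.mock`. The `call_converse` function and the `build_session` helper are hypothetical and exist only to exercise the mock; they are not the service's actual implementation.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def call_converse(session):
    """Hypothetical code under test: open a client and read the first text block."""
    async with session.client("bedrock-runtime") as client:
        response = await client.converse(messages=[])
    return response["output"]["message"]["content"][0]["text"]


def build_session(text):
    """Build a fake session whose client() returns an async context manager."""
    mock_client = AsyncMock()
    mock_client.converse.return_value = {
        "output": {"message": {"content": [{"text": text}]}}
    }

    # `async with` awaits __aenter__ and __aexit__, so both must be async mocks.
    context_manager = AsyncMock()
    context_manager.__aenter__ = AsyncMock(return_value=mock_client)
    context_manager.__aexit__ = AsyncMock(return_value=None)  # falsy: do not swallow errors

    session = MagicMock()
    session.client.return_value = context_manager
    return session, mock_client


session, mock_client = build_session("Hello!")
result = asyncio.run(call_converse(session))
```

Because `__aexit__` returns `None`, an exception raised through `mock_client.converse.side_effect` still propagates out of the `async with` block, which is what the `*_client_exception` tests depend on.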


@@ -1,303 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Unit tests for ServiceSwitcher and related components."""
import unittest
from pipecat.frames.frames import (
Frame,
ManuallySwitchServiceFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import run_test
class MockFrameProcessor(FrameProcessor):
"""A test frame processor that tracks which frames it has processed."""
def __init__(self, test_name: str, **kwargs):
"""Initialize the test processor with a name.
Args:
test_name: A unique name for this processor instance.
**kwargs: Additional arguments passed to the parent FrameProcessor.
"""
super().__init__(name=test_name, **kwargs)
self.test_name = test_name
self.processed_frames = []
self.frame_count = 0
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process an incoming frame and track it.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
self.processed_frames.append(frame)
self.frame_count += 1
await self.push_frame(frame, direction)
def reset_counters(self):
"""Reset the frame tracking counters."""
self.processed_frames = []
self.frame_count = 0
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcherStrategyManual."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_services(self):
"""Test initialization with a list of services."""
strategy = ServiceSwitcherStrategyManual(self.services)
self.assertEqual(strategy.services, self.services)
self.assertEqual(strategy.active_service, self.service1) # First service should be active
def test_init_with_empty_services(self):
"""Test initialization with an empty list of services."""
strategy = ServiceSwitcherStrategyManual([])
self.assertEqual(strategy.services, [])
self.assertIsNone(strategy.active_service)
def test_handle_manually_switch_service_frame(self):
"""Test manual service switching with ManuallySwitchServiceFrame."""
strategy = ServiceSwitcherStrategyManual(self.services)
# Initially service1 should be active
self.assertEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
# Switch to service2
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertEqual(strategy.active_service, self.service2)
self.assertNotEqual(strategy.active_service, self.service3)
# Switch to service3
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
self.assertEqual(strategy.active_service, self.service3)
def test_handle_frame_unsupported_frame_type(self):
"""Test that unsupported frame types raise an error."""
strategy = ServiceSwitcherStrategyManual(self.services)
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
with self.assertRaises(ValueError) as context:
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
self.assertIn("Unsupported frame type", str(context.exception))
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcher."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_manual_strategy(self):
"""Test initialization with manual strategy."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.services, self.services)
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.strategy.services, self.services)
async def test_default_active_service(self):
"""Test that the initially-active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send some test frames
frames_to_send = [
TextFrame(text="Hello 1"),
TextFrame(text="Hello 2"),
TextFrame(text="Hello 3"),
]
await run_test(
switcher,
frames_to_send=frames_to_send,
expected_down_frames=[TextFrame, TextFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Only service1 should have processed the text frames
# Note: The service also receives StartFrame and EndFrame, so count those too
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
self.assertEqual(len(text_frames), 3)
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service2_text_frames), 0)
self.assertEqual(len(service3_text_frames), 0)
# Verify the actual text frames processed
for i, frame in enumerate(text_frames):
self.assertEqual(frame.text, f"Hello {i + 1}")
async def test_service_switching(self):
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send a test frame, a switch frame, and another test frame
await run_test(
switcher,
frames_to_send=[
TextFrame("Hello 1"),
ManuallySwitchServiceFrame(service=self.service2),
TextFrame("Hello 2"),
],
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Verify service2 received the frame
service1_text_frames = [
f for f in self.service1.processed_frames if isinstance(f, TextFrame)
]
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service1_text_frames), 1)
self.assertEqual(len(service2_text_frames), 1)
self.assertEqual(len(service3_text_frames), 0)
self.assertEqual(service1_text_frames[0].text, "Hello 1")
self.assertEqual(service2_text_frames[0].text, "Hello 2")
async def test_multi_service_switcher_targeting(self):
"""Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline."""
# Create services for first switcher
switcher1_service1 = MockFrameProcessor("switcher1_service1")
switcher1_service2 = MockFrameProcessor("switcher1_service2")
switcher1_services = [switcher1_service1, switcher1_service2]
# Create services for second switcher
switcher2_service1 = MockFrameProcessor("switcher2_service1")
switcher2_service2 = MockFrameProcessor("switcher2_service2")
switcher2_services = [switcher2_service1, switcher2_service2]
# Create two service switchers
switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)
# Create a pipeline with both switchers: switcher1 -> switcher2
pipeline = Pipeline([switcher1, switcher2])
# Reset counters
for service in switcher1_services + switcher2_services:
service.reset_counters()
# Initially, both switchers should use their first services
self.assertEqual(switcher1.strategy.active_service, switcher1_service1)
self.assertEqual(switcher2.strategy.active_service, switcher2_service1)
# Send frames to test the pipeline:
# 1. Text frame (should go through both switchers' active services)
# 2. Switch frame targeting switcher1's second service
# 3. Text frame (should go through switcher1's new service and switcher2's original service)
# 4. Switch frame targeting switcher2's second service
# 5. Text frame (should go through switcher1's current service and switcher2's new service)
await run_test(
pipeline,
frames_to_send=[
TextFrame("Before any switches"),
ManuallySwitchServiceFrame(service=switcher1_service2), # Switch first switcher
TextFrame("After switching first switcher"),
ManuallySwitchServiceFrame(service=switcher2_service2), # Switch second switcher
TextFrame("After switching second switcher"),
],
expected_down_frames=[
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
],
expected_up_frames=[], # Expect no error frames
)
# Verify the active services changed correctly
self.assertEqual(switcher1.strategy.active_service, switcher1_service2)
self.assertEqual(switcher2.strategy.active_service, switcher2_service2)
# Verify frame distribution:
# First text frame should go through switcher1_service1 and switcher2_service1
switcher1_service1_texts = [
f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame)
]
switcher2_service1_texts = [
f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame)
]
# Second text frame should go through switcher1_service2 and switcher2_service1
switcher1_service2_texts = [
f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame)
]
# Third text frame should go through switcher1_service2 and switcher2_service2
switcher2_service2_texts = [
f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame)
]
# Verify frame counts and content
self.assertEqual(len(switcher1_service1_texts), 1)
self.assertEqual(switcher1_service1_texts[0].text, "Before any switches")
self.assertEqual(len(switcher1_service2_texts), 2)
self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher")
self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher")
self.assertEqual(len(switcher2_service1_texts), 2)
self.assertEqual(switcher2_service1_texts[0].text, "Before any switches")
self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher")
self.assertEqual(len(switcher2_service2_texts), 1)
self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")
if __name__ == "__main__":
unittest.main()
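
The behavior these tests expect from a manual switching strategy can be summarized in a short sketch: the first service starts active, an empty list leaves no active service, a switch frame changes the active service, and any other frame raises `ValueError`. `SwitchFrame` and `ManualStrategy` below are hypothetical toy classes, not pipecat's `ManuallySwitchServiceFrame` or `ServiceSwitcherStrategyManual`. Ignoring a switch frame that names a service outside the strategy's list is an assumption inferred from the multi-switcher test, not confirmed by the source.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class SwitchFrame:
    """Hypothetical frame that names the service to activate."""

    service: Any


class ManualStrategy:
    """Toy strategy: tracks which of its services is currently active."""

    def __init__(self, services):
        self.services = list(services)
        # The first service is active by default; None when the list is empty.
        self.active_service = self.services[0] if self.services else None

    def handle_frame(self, frame):
        if not isinstance(frame, SwitchFrame):
            raise ValueError(f"Unsupported frame type: {type(frame).__name__}")
        # Assumption: a frame targeting a service owned by another switcher is ignored.
        if frame.service in self.services:
            self.active_service = frame.service


strategy = ManualStrategy(["stt_a", "stt_b"])
strategy.handle_frame(SwitchFrame(service="stt_b"))
active_after_switch = strategy.active_service

strategy.handle_frame(SwitchFrame(service="tts_x"))  # not in this strategy's list
active_after_foreign = strategy.active_service
```

Keeping the membership check inside the strategy is one way to let the same switch frame travel through several switchers in a pipeline while only the owning switcher reacts.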


@@ -14,7 +14,7 @@ from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
InterruptionFrame,
StartInterruptionFrame,
TranscriptionFrame,
TranscriptionMessage,
TranscriptionUpdateFrame,
@@ -238,7 +238,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
SleepFrame(),
InterruptionFrame(), # User interrupts here
StartInterruptionFrame(), # User interrupts here
SleepFrame(),
BotStartedSpeakingFrame(),
TTSTextFrame(text="New"),
@@ -252,7 +252,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
BotStartedSpeakingFrame,
TTSTextFrame, # "Hello"
TTSTextFrame, # "world!"
InterruptionFrame,
StartInterruptionFrame,
TranscriptionUpdateFrame, # First message (emitted due to interruption)
BotStartedSpeakingFrame,
TTSTextFrame, # "New"

Some files were not shown because too many files have changed in this diff.