Compare commits
83 Commits
v0.0.85
...
cb/frame-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05863ded53 | ||
|
|
6ab4a48d8f | ||
|
|
89e0092159 | ||
|
|
0de31dab79 | ||
|
|
10ff93307d | ||
|
|
414c9e3bc8 | ||
|
|
ba64f126a3 | ||
|
|
d28e3881a7 | ||
|
|
7df7395dd1 | ||
|
|
0885bc9cdf | ||
|
|
0204f6a95d | ||
|
|
b0bf653f04 | ||
|
|
e8a676eb36 | ||
|
|
ca96eef1f3 | ||
|
|
8e1637d6c7 | ||
|
|
367200c0ad | ||
|
|
766e1948a6 | ||
|
|
f369683b8b | ||
|
|
461025d1cc | ||
|
|
ac88706f38 | ||
|
|
93a89449b8 | ||
|
|
199bf72945 | ||
|
|
d20e4125f6 | ||
|
|
c1baed642e | ||
|
|
33ef68573f | ||
|
|
3c1b41df13 | ||
|
|
fca4ecc73c | ||
|
|
cfa333508b | ||
|
|
9e7260393a | ||
|
|
073b585c52 | ||
|
|
81c2e51bec | ||
|
|
42344125b1 | ||
|
|
db5bcfaa51 | ||
|
|
615239b7d2 | ||
|
|
27f1e9dd69 | ||
|
|
bd760deff2 | ||
|
|
8bc3c89140 | ||
|
|
2cd2567a37 | ||
|
|
5b55988846 | ||
|
|
a12392182c | ||
|
|
b814b70e1e | ||
|
|
a1f84e1b50 | ||
|
|
0839b48da8 | ||
|
|
de51637b77 | ||
|
|
e1b1dc16ec | ||
|
|
1fe27eb0a2 | ||
|
|
d7e1389497 | ||
|
|
8c7230aa8f | ||
|
|
2cf71239b0 | ||
|
|
ec2c62e32b | ||
|
|
38ce85e9a0 | ||
|
|
2279e5a899 | ||
|
|
cce6eb5d87 | ||
|
|
c2b98ae557 | ||
|
|
727eb12b16 | ||
|
|
ba96bd05d3 | ||
|
|
8ead309f8d | ||
|
|
fad0e55c64 | ||
|
|
74b1af56a0 | ||
|
|
6924850ec4 | ||
|
|
dfe7815dc5 | ||
|
|
69f0a75882 | ||
|
|
cca90791c4 | ||
|
|
f2a5d408de | ||
|
|
044c6eba46 | ||
|
|
db71089f5e | ||
|
|
f861f5066f | ||
|
|
81cede2c60 | ||
|
|
7603203230 | ||
|
|
8569b61598 | ||
|
|
fe42187dc1 | ||
|
|
999e88c942 | ||
|
|
c04df2f28b | ||
|
|
100ef0ab5c | ||
|
|
42886d7105 | ||
|
|
22cbba002a | ||
|
|
0a043154f2 | ||
|
|
5e322eba9e | ||
|
|
11d0c3d46d | ||
|
|
c873798ce5 | ||
|
|
95f72f6dce | ||
|
|
786387722a | ||
|
|
9f82c6b4a4 |
100
CHANGELOG.md
100
CHANGELOG.md
@@ -5,6 +5,106 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_before_disconnect` synchronous event to `DailyTransport` and
|
||||
`LiveKitTransport`.
|
||||
|
||||
- It is now possible to register synchronous event handlers. By default, all
|
||||
event handlers are executed in a separate task. However, in some cases we want
|
||||
to guarantee order of execution, for example, executing something before
|
||||
disconnecting a transport.
|
||||
|
||||
```python
|
||||
self._register_event_handler("on_event_name", sync=True)
|
||||
```
|
||||
|
||||
- Added support for global location in `GoogleVertexLLMService`. The service now
|
||||
supports both regional locations (e.g., "us-east4") and the "global" location
|
||||
for Vertex AI endpoints. When using "global" location, the service will use
|
||||
`aiplatform.googleapis.com` as the API host instead of the regional format.
|
||||
|
||||
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
|
||||
fired when the pipeline is done running. This can be the result of a
|
||||
`StopFrame`, `CancelFrame` or `EndFrame`.
|
||||
|
||||
```python
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
...
|
||||
```
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Silero VAD model to v6.
|
||||
|
||||
- Updated `livekit` to 1.0.13.
|
||||
|
||||
- `torch` and `torchaudio` are no longer required for running Smart Turn
|
||||
locally. This avoids gigabytes of dependencies being installed.
|
||||
|
||||
- Updated `websockets` dependency to support version 15.0. Removed deprecated
|
||||
usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
|
||||
`AWSTranscribeSTTService` for compatibility.
|
||||
|
||||
- Refactored `pyproject.toml` to reduce websockets dependency repetition using
|
||||
self-referencing extras. All websockets-dependent services now reference a
|
||||
shared `websockets-base` extra.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
|
||||
longer needed to determine which transcription or translation frames to
|
||||
emit.
|
||||
|
||||
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
|
||||
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
|
||||
instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where multiple handlers for an event would not run in parallel.
|
||||
|
||||
- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
|
||||
ID from the `on_dialin_connected` event, when not explicitly provided. Now
|
||||
supports cold transfers (from incoming dial-in calls) by automatically
|
||||
tracking session IDs from connection events.
|
||||
|
||||
- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
|
||||
a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
|
||||
the code never consumes these frames, they are queued in memory, causing a
|
||||
memory leak.
|
||||
|
||||
- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
|
||||
pushed.
|
||||
|
||||
- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
|
||||
not wait if a previous interruption had already happened.
|
||||
|
||||
- Fixed a couple of bugs in `ServiceSwitcher`:
|
||||
|
||||
- Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
|
||||
- `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
|
||||
an effect too early, essentially "jumping the queue" in terms of pipeline
|
||||
frame ordering.
|
||||
|
||||
- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
|
||||
`False` from an idle callback. The task now terminates naturally instead of
|
||||
attempting to cancel itself.
|
||||
|
||||
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
|
||||
when a bot speaks and user input is blocked.
|
||||
|
||||
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
|
||||
`on_client_disconnected` would be triggered when the bot ends the
|
||||
conversation. That is, `on_client_disconnected` should only be triggered when
|
||||
the remote client actually disconnects.
|
||||
|
||||
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
|
||||
was blocked from moving through the Pipeline.
|
||||
|
||||
## [0.0.85] - 2025-09-12
|
||||
|
||||
### Added
|
||||
|
||||
@@ -21,6 +21,8 @@
|
||||
|
||||
🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.
|
||||
|
||||
🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
|
||||
|
||||
## 🧠 Why Pipecat?
|
||||
|
||||
- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling
|
||||
|
||||
@@ -11,7 +11,7 @@ import sys
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
@@ -50,7 +50,7 @@ async def main():
|
||||
async def on_first_participant_joined(transport, participant_id):
|
||||
await asyncio.sleep(1)
|
||||
await task.queue_frame(
|
||||
TextFrame(
|
||||
TTSSpeakFrame(
|
||||
"Hello there! How are you doing today? Would you like to talk about the weather?"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,10 +30,6 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
|
||||
# to the Smart Turn v3 ONNX model file.
|
||||
smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -42,25 +38,19 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -47,32 +47,32 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.0.1" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "websockets>=13.1,<15.0" ]
|
||||
asyncai = [ "websockets>=13.1,<15.0" ]
|
||||
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.19.9" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "websockets>=13.1,<15.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
|
||||
gladia = [ "websockets>=13.1,<15.0" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
|
||||
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
|
||||
inworld = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
@@ -80,34 +80,35 @@ mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
rime = [ "websockets>=13.1,<15.0" ]
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sambanova = []
|
||||
sarvam = [ "websockets>=13.1,<15.0" ]
|
||||
sarvam = [ "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime>=1.20.1, <2" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime>=1.20.1, <2" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "websockets>=13.1,<15.0" ]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.4.0" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
|
||||
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
|
||||
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
|
||||
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
|
||||
websockets-base = [ "websockets>=13.1,<16.0" ]
|
||||
whisper = [ "faster-whisper~=1.1.1" ]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
12
scripts/mem-watch.sh
Executable file
12
scripts/mem-watch.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
PID=$1
|
||||
|
||||
while true; do
|
||||
# Clear the screen
|
||||
clear
|
||||
# Print the header + RSS in GB
|
||||
ps -p "$PID" -o pid,comm,rss | \
|
||||
awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
|
||||
sleep 1
|
||||
done
|
||||
@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
|
||||
|
||||
Returns:
|
||||
The identifier string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
|
||||
|
||||
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
"""Get messages from the LLM context, including standard and LLM-specific messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages including standard and LLM-specific messages.
|
||||
"""
|
||||
return context.get_messages(self.id_for_llm_specific_messages)
|
||||
|
||||
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
|
||||
return "anthropic"
|
||||
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
@@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
@@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
specific function-calling format, enabling tool use with Nova Sonic models.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
into AWS Bedrock's expected tool format for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
|
||||
return "aws"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": messages.messages,
|
||||
@@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about AWS Bedrock.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("google")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Google-formatted messages converted from universal context."""
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
|
||||
return "openai"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -57,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self._get_messages(context)),
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
@@ -91,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self._get_messages(context):
|
||||
for message in self.get_messages(context):
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
@@ -104,14 +110,18 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("openai")
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
# Just a pass-through: messages are already the right type
|
||||
return messages
|
||||
result = []
|
||||
for message in messages:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Extract the actual message content from LLMSpecificMessage
|
||||
result.append(message.message)
|
||||
else:
|
||||
# Standard message, pass through unchanged
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
def _from_standard_tool_choice(
|
||||
self, tool_choice: LLMContextToolChoice | NotGiven
|
||||
|
||||
@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
OpenAI's Realtime API for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
|
||||
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -98,15 +98,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
inputs = self._feature_extractor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
return_tensors="pt",
|
||||
return_tensors="np",
|
||||
padding="max_length",
|
||||
max_length=8 * 16000,
|
||||
truncation=True,
|
||||
do_normalize=True,
|
||||
)
|
||||
|
||||
# Convert to numpy and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32)
|
||||
# Extract features and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).astype(np.float32)
|
||||
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
|
||||
|
||||
# Run ONNX inference
|
||||
|
||||
Binary file not shown.
@@ -1604,7 +1604,7 @@ class MixerEnableFrame(MixerControlFrame):
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFrame(ControlFrame):
|
||||
"""A base class for frames that control ServiceSwitcher behavior."""
|
||||
"""A base class for frames that affect ServiceSwitcher behavior."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,9 +6,15 @@
|
||||
|
||||
"""Service switcher for switching between different services at runtime, with different switching strategies."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
ServiceSwitcherFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -22,19 +28,6 @@ class ServiceSwitcherStrategy:
|
||||
self.services = services
|
||||
self.active_service: Optional[FrameProcessor] = None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Determine if the given service is the currently active one.
|
||||
|
||||
This method should be overridden by subclasses to implement specific logic.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method.")
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -60,17 +53,6 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
super().__init__(services)
|
||||
self.active_service = services[0] if services else None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Check if the given service is the currently active one.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
return service == self.active_service
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -79,20 +61,21 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
"""
|
||||
if isinstance(frame, ManuallySwitchServiceFrame):
|
||||
self._set_active(frame.service)
|
||||
self._set_active_if_available(frame.service)
|
||||
else:
|
||||
raise ValueError(f"Unsupported frame type: {type(frame)}")
|
||||
|
||||
def _set_active(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one.
|
||||
def _set_active_if_available(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
intended for another ServiceSwitcher in the pipeline.
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
"""
|
||||
if service in self.services:
|
||||
self.active_service = service
|
||||
else:
|
||||
raise ValueError(f"Service {service} is not in the list of available services.")
|
||||
|
||||
|
||||
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
|
||||
@@ -108,6 +91,43 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
self.services = services
|
||||
self.strategy = strategy
|
||||
|
||||
class ServiceSwitcherFilter(FunctionFilter):
|
||||
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapped_service: FrameProcessor,
|
||||
active_service: FrameProcessor,
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction."""
|
||||
|
||||
async def filter(_: Frame) -> bool:
|
||||
return self._wrapped_service == self._active_service
|
||||
|
||||
super().__init__(filter, direction)
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
|
||||
self._active_service = frame.active_service
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. Push the
|
||||
# frame only to update the other side of the sandwich, but
|
||||
# otherwise don't let it leave the sandwich.
|
||||
if direction == self._direction:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
|
||||
|
||||
active_service: FrameProcessor
|
||||
|
||||
@staticmethod
|
||||
def _make_pipeline_definitions(
|
||||
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
|
||||
@@ -121,14 +141,18 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def _make_pipeline_definition(
|
||||
service: FrameProcessor, strategy: ServiceSwitcherStrategy
|
||||
) -> Any:
|
||||
async def filter(frame) -> bool:
|
||||
_ = frame
|
||||
return strategy.is_active(service)
|
||||
|
||||
return [
|
||||
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.DOWNSTREAM,
|
||||
),
|
||||
service,
|
||||
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.UPSTREAM,
|
||||
),
|
||||
]
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -142,3 +166,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
|
||||
if isinstance(frame, ServiceSwitcherFrame):
|
||||
self.strategy.handle_frame(frame, direction)
|
||||
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
|
||||
active_service=self.strategy.active_service
|
||||
)
|
||||
await super().process_frame(service_switcher_filter_frame, direction)
|
||||
|
||||
@@ -115,9 +115,28 @@ class PipelineTask(BasePipelineTask):
|
||||
- on_frame_reached_downstream: Called when downstream frames reach the sink
|
||||
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
|
||||
- on_pipeline_started: Called when pipeline starts with StartFrame
|
||||
- on_pipeline_stopped: Called when pipeline stops with StopFrame
|
||||
- on_pipeline_ended: Called when pipeline ends with EndFrame
|
||||
- on_pipeline_cancelled: Called when pipeline is cancelled
|
||||
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
|
||||
This includes:
|
||||
- StopFrame: pipeline was stopped (processors keep connections open)
|
||||
- EndFrame: pipeline ended normally
|
||||
- CancelFrame: pipeline was cancelled
|
||||
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
|
||||
the frame if they need to handle specific cases.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -128,6 +147,10 @@ class PipelineTask(BasePipelineTask):
|
||||
@task.event_handler("on_idle_timeout")
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -264,6 +287,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._register_event_handler("on_pipeline_stopped")
|
||||
self._register_event_handler("on_pipeline_ended")
|
||||
self._register_event_handler("on_pipeline_cancelled")
|
||||
self._register_event_handler("on_pipeline_finished")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
@@ -292,6 +316,27 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._turn_trace_observer
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to handle.
|
||||
|
||||
Returns:
|
||||
The decorator function that registers the handler.
|
||||
"""
|
||||
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return super().event_handler(event_name)
|
||||
|
||||
def add_observer(self, observer: BaseObserver):
|
||||
"""Add an observer to monitor pipeline execution.
|
||||
|
||||
@@ -534,6 +579,7 @@ class PipelineTask(BasePipelineTask):
|
||||
)
|
||||
finally:
|
||||
await self._call_event_handler("on_pipeline_cancelled", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
|
||||
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
|
||||
|
||||
@@ -681,9 +727,11 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_start_event.set()
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._call_event_handler("on_pipeline_ended", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, StopFrame):
|
||||
await self._call_event_handler("on_pipeline_stopped", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
|
||||
@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
return self._num_channels
|
||||
|
||||
def has_audio(self) -> bool:
|
||||
"""Check if both user and bot audio buffers contain data.
|
||||
"""Check if either user or bot audio buffers contain data.
|
||||
|
||||
Returns:
|
||||
True if both buffers contain audio data.
|
||||
True if either buffer contains audio data.
|
||||
"""
|
||||
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
|
||||
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
|
||||
self._bot_audio_buffer
|
||||
)
|
||||
|
||||
|
||||
@@ -220,6 +220,11 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
# the start of the pipeline back to the processor that sent the
|
||||
# `InterruptionTaskFrame`. This wait is handled using the following
|
||||
# event.
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
@@ -563,11 +568,17 @@ class FrameProcessor(BaseObject):
|
||||
"""Pause processing of queued frames."""
|
||||
logger.trace(f"{self}: pausing frame processing")
|
||||
self.__should_block_frames = True
|
||||
# We should also unset the process event here, in case it was set immediately after an interruption
|
||||
if self.__process_event:
|
||||
self.__process_event.clear()
|
||||
|
||||
async def pause_processing_system_frames(self):
|
||||
"""Pause processing of queued system frames."""
|
||||
logger.trace(f"{self}: pausing system frame processing")
|
||||
self.__should_block_system_frames = True
|
||||
# We should also unset the input event here, in case it was set immediately after an interruption
|
||||
if self.__input_event:
|
||||
self.__input_event.clear()
|
||||
|
||||
async def resume_processing_frames(self):
|
||||
"""Resume processing of queued frames."""
|
||||
@@ -632,7 +643,9 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
# If we are waiting for an interruption and we get an interruption, then
|
||||
# we can unblock `push_interruption_task_frame_and_wait()`.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
|
||||
@@ -17,7 +17,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -185,15 +184,13 @@ class UserIdleProcessor(FrameProcessor):
|
||||
|
||||
Runs in a loop until cancelled or callback indicates completion.
|
||||
"""
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._interrupted:
|
||||
self._retry_count += 1
|
||||
should_continue = await self._callback(self, self._retry_count)
|
||||
if not should_continue:
|
||||
await self._stop()
|
||||
break
|
||||
running = await self._callback(self, self._retry_count)
|
||||
finally:
|
||||
self._idle_event.clear()
|
||||
|
||||
@@ -70,7 +70,6 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -183,13 +182,14 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.request_handler import (
|
||||
SmallWebRTCRequest,
|
||||
SmallWebRTCRequestHandler,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"WebRTC transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
# Mount the frontend
|
||||
app.mount("/client", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -198,51 +198,33 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
"""Redirect root requests to client interface."""
|
||||
return RedirectResponse(url="/client/")
|
||||
|
||||
# Initialize the SmallWebRTC request handler
|
||||
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
|
||||
esp32_mode=esp32_mode, host=host
|
||||
)
|
||||
|
||||
@app.post("/api/offer")
|
||||
async def offer(request: dict, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests and manage peer connections."""
|
||||
pc_id = request.get("pc_id")
|
||||
|
||||
if pc_id and pc_id in pcs_map:
|
||||
pipecat_connection = pcs_map[pc_id]
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(
|
||||
sdp=request["sdp"],
|
||||
type=request["type"],
|
||||
restart_pc=request.get("restart_pc", False),
|
||||
)
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection()
|
||||
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
"""Handle WebRTC connection closure and cleanup."""
|
||||
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
|
||||
pcs_map.pop(webrtc_connection.pc_id, None)
|
||||
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
|
||||
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
|
||||
# Apply ESP32 SDP munging if enabled
|
||||
if esp32_mode and host != "localhost":
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
|
||||
|
||||
pcs_map[answer["pc_id"]] = pipecat_connection
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
answer = await small_webrtc_handler.handle_web_request(
|
||||
request=request,
|
||||
webrtc_connection_callback=webrtc_connection_callback,
|
||||
)
|
||||
return answer
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage FastAPI application lifecycle and cleanup connections."""
|
||||
yield
|
||||
coros = [pc.disconnect() for pc in pcs_map.values()]
|
||||
await asyncio.gather(*coros)
|
||||
pcs_map.clear()
|
||||
await small_webrtc_handler.close()
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
|
||||
@@ -119,7 +119,6 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -811,60 +811,55 @@ class AWSBedrockLLMService(LLMService):
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _create_converse_stream(self, client, request_params):
|
||||
|
||||
@@ -532,9 +532,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(
|
||||
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
|
||||
@@ -13,6 +13,7 @@ supporting multiple languages, custom vocabulary, and various audio processing o
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -173,8 +174,6 @@ class _InputParamsDescriptor:
|
||||
"""Descriptor for backward compatibility with deprecation warning."""
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
@@ -208,7 +207,7 @@ class GladiaSTTService(STTService):
|
||||
api_key: str,
|
||||
region: Literal["us-west", "eu-west"] | None = None,
|
||||
url: str = "https://api.gladia.io/v2/live",
|
||||
confidence: float = 0.5,
|
||||
confidence: Optional[float] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
@@ -224,6 +223,11 @@ class GladiaSTTService(STTService):
|
||||
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
|
||||
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
|
||||
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
The 'confidence' parameter is deprecated and will be removed in a future version.
|
||||
No confidence threshold is applied.
|
||||
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
model: Model to use for transcription. Defaults to "solaria-1".
|
||||
params: Additional configuration parameters for Gladia service.
|
||||
@@ -236,7 +240,6 @@ class GladiaSTTService(STTService):
|
||||
|
||||
params = params or GladiaInputParams()
|
||||
|
||||
# Warn about deprecated language parameter if it's used
|
||||
if params.language is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
@@ -247,11 +250,20 @@ class GladiaSTTService(STTService):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if confidence:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The 'confidence' parameter is deprecated and will be removed in a future version. "
|
||||
"No confidence threshold is applied.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._url = url
|
||||
self.set_model_name(model)
|
||||
self._confidence = confidence
|
||||
self._params = params
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
@@ -575,43 +587,40 @@ class GladiaSTTService(STTService):
|
||||
|
||||
elif content["type"] == "transcript":
|
||||
utterance = content["data"]["utterance"]
|
||||
confidence = utterance.get("confidence", 0)
|
||||
language = utterance["language"]
|
||||
transcript = utterance["text"]
|
||||
is_final = content["data"]["is_final"]
|
||||
if confidence >= self._confidence:
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
elif content["type"] == "translation":
|
||||
translated_utterance = content["data"]["translated_utterance"]
|
||||
original_language = content["data"]["original_language"]
|
||||
translated_language = translated_utterance["language"]
|
||||
confidence = translated_utterance.get("confidence", 0)
|
||||
translation = translated_utterance["text"]
|
||||
if translated_language != original_language and confidence >= self._confidence:
|
||||
if translated_language != original_language:
|
||||
await self.push_frame(
|
||||
TranslationFrame(
|
||||
translation, "", time_now_iso8601(), translated_language
|
||||
|
||||
@@ -83,14 +83,23 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
self._api_key = self._get_api_token(credentials, credentials_path)
|
||||
|
||||
super().__init__(
|
||||
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
|
||||
api_key=self._api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_base_url(params: InputParams) -> str:
|
||||
"""Construct the base URL for Vertex AI API."""
|
||||
# Determine the correct API host based on location
|
||||
if params.location == "global":
|
||||
api_host = "aiplatform.googleapis.com"
|
||||
else:
|
||||
api_host = f"{params.location}-aiplatform.googleapis.com"
|
||||
return (
|
||||
f"https://{params.location}-aiplatform.googleapis.com/v1/"
|
||||
f"https://{api_host}/v1/"
|
||||
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
|
||||
)
|
||||
|
||||
@@ -118,12 +127,14 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
if credentials:
|
||||
# Parse and load credentials from JSON string
|
||||
creds = service_account.Credentials.from_service_account_info(
|
||||
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
json.loads(credentials),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
elif credentials_path:
|
||||
# Load credentials from JSON file
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
credentials_path,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -240,6 +240,7 @@ class HeyGenVideoService(AIService):
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -195,6 +195,17 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return self.get_llm_adapter().create_llm_specific_message(message)
|
||||
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -28,6 +28,8 @@ except ModuleNotFoundError as e:
|
||||
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
class MCPClient(BaseObject):
|
||||
"""Client for Model Context Protocol (MCP) servers.
|
||||
@@ -42,7 +44,7 @@ class MCPClient(BaseObject):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
|
||||
server_params: ServerParameters,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
@@ -25,6 +25,7 @@ from pydantic import BaseModel
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
ControlFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -41,6 +42,7 @@ from pipecat.frames.frames import (
|
||||
UserAudioRawFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
DataFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -105,6 +107,17 @@ class DailyInputTransportMessageUrgentFrame(InputTransportMessageUrgentFrame):
|
||||
participant_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
|
||||
"""Frame to update remote participants in Daily calls.
|
||||
|
||||
Parameters:
|
||||
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
|
||||
"""
|
||||
|
||||
remote_participants: Mapping[str, Any] = None
|
||||
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
"""Voice Activity Detection analyzer using WebRTC.
|
||||
|
||||
@@ -215,6 +228,7 @@ class DailyCallbacks(BaseModel):
|
||||
on_active_speaker_changed: Called when the active speaker of the call has changed.
|
||||
on_joined: Called when bot successfully joined a room.
|
||||
on_left: Called when bot left a room.
|
||||
on_before_leave: Called when bot is about to leave the room.
|
||||
on_error: Called when an error occurs.
|
||||
on_app_message: Called when receiving an app message.
|
||||
on_call_state_updated: Called when call state changes.
|
||||
@@ -244,6 +258,7 @@ class DailyCallbacks(BaseModel):
|
||||
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_left: Callable[[], Awaitable[None]]
|
||||
on_before_leave: Callable[[], Awaitable[None]]
|
||||
on_error: Callable[[str], Awaitable[None]]
|
||||
on_app_message: Callable[[Any, str], Awaitable[None]]
|
||||
on_call_state_updated: Callable[[str], Awaitable[None]]
|
||||
@@ -359,6 +374,7 @@ class DailyTransportClient(EventHandler):
|
||||
self._transcription_ids = []
|
||||
self._transcription_status = None
|
||||
self._dial_out_session_id: str = ""
|
||||
self._dial_in_session_id: str = ""
|
||||
|
||||
self._joining = False
|
||||
self._joined = False
|
||||
@@ -719,6 +735,9 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
logger.info(f"Leaving {self._room_url}")
|
||||
|
||||
# Call callback before leaving.
|
||||
await self._callbacks.on_before_leave()
|
||||
|
||||
if self._params.transcription_enabled:
|
||||
await self.stop_transcription()
|
||||
|
||||
@@ -823,6 +842,16 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
settings: SIP call transfer settings.
|
||||
"""
|
||||
session_id = (
|
||||
settings.get("sessionId") or self._dial_out_session_id or self._dial_in_session_id
|
||||
)
|
||||
if not session_id:
|
||||
logger.error("Unable to transfer SIP call: 'sessionId' is not set")
|
||||
return
|
||||
|
||||
# Update 'sessionId' field.
|
||||
settings["sessionId"] = session_id
|
||||
|
||||
future = self._get_event_loop().create_future()
|
||||
self._client.sip_call_transfer(settings, completion=completion_callback(future))
|
||||
await future
|
||||
@@ -1141,6 +1170,7 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in connection data.
|
||||
"""
|
||||
self._dial_in_session_id = data["sessionId"] if "sessionId" in data else ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_connected, data)
|
||||
|
||||
def on_dialin_ready(self, sip_endpoint: str):
|
||||
@@ -1157,6 +1187,9 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in stop data.
|
||||
"""
|
||||
# Cleanup only if our session stopped.
|
||||
if data.get("sessionId") == self._dial_in_session_id:
|
||||
self._dial_in_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_stopped, data)
|
||||
|
||||
def on_dialin_error(self, data: Any):
|
||||
@@ -1165,6 +1198,9 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in error data.
|
||||
"""
|
||||
# Cleanup only if our session errored out.
|
||||
if data.get("sessionId") == self._dial_in_session_id:
|
||||
self._dial_in_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_error, data)
|
||||
|
||||
def on_dialin_warning(self, data: Any):
|
||||
@@ -1199,7 +1235,7 @@ class DailyTransportClient(EventHandler):
|
||||
data: Dial-out stop data.
|
||||
"""
|
||||
# Cleanup only if our session stopped.
|
||||
if data["sessionId"] == self._dial_out_session_id:
|
||||
if data.get("sessionId") == self._dial_out_session_id:
|
||||
self._dial_out_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialout_stopped, data)
|
||||
|
||||
@@ -1210,7 +1246,7 @@ class DailyTransportClient(EventHandler):
|
||||
data: Dial-out error data.
|
||||
"""
|
||||
# Cleanup only if our session errored out.
|
||||
if data["sessionId"] == self._dial_out_session_id:
|
||||
if data.get("sessionId") == self._dial_out_session_id:
|
||||
self._dial_out_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialout_error, data)
|
||||
|
||||
@@ -1767,6 +1803,31 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process outgoing frames, including transport messages.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
logger.debug(f"Got a DailyUpdateRemoteParticipantsFrame: {frame}")
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process outgoing frames, including transport messages.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
"""Send a transport message to participants.
|
||||
|
||||
@@ -1862,6 +1923,7 @@ class DailyTransport(BaseTransport):
|
||||
on_active_speaker_changed=self._on_active_speaker_changed,
|
||||
on_joined=self._on_joined,
|
||||
on_left=self._on_left,
|
||||
on_before_leave=self._on_before_leave,
|
||||
on_error=self._on_error,
|
||||
on_app_message=self._on_app_message,
|
||||
on_call_state_updated=self._on_call_state_updated,
|
||||
@@ -1925,6 +1987,10 @@ class DailyTransport(BaseTransport):
|
||||
self._register_event_handler("on_recording_started")
|
||||
self._register_event_handler("on_recording_stopped")
|
||||
self._register_event_handler("on_recording_error")
|
||||
self._register_event_handler("on_before_disconnect", sync=True)
|
||||
# Deprecated
|
||||
self._register_event_handler("on_joined")
|
||||
self._register_event_handler("on_left")
|
||||
|
||||
#
|
||||
# BaseTransport
|
||||
@@ -2176,6 +2242,10 @@ class DailyTransport(BaseTransport):
|
||||
"""Handle room left events."""
|
||||
await self._call_event_handler("on_left")
|
||||
|
||||
async def _on_before_leave(self):
|
||||
"""Handle before leave room events."""
|
||||
await self._call_event_handler("on_before_disconnect")
|
||||
|
||||
async def _on_error(self, error):
|
||||
"""Handle error events and push error frames."""
|
||||
await self._call_event_handler("on_error", error)
|
||||
@@ -2315,7 +2385,7 @@ class DailyTransport(BaseTransport):
|
||||
"""Handle participant updated events."""
|
||||
await self._call_event_handler("on_participant_updated", participant)
|
||||
|
||||
async def _on_transcription_message(self, message: Dict[str, Any]) -> None:
|
||||
async def _on_transcription_message(self, message: Mapping[str, Any]) -> None:
|
||||
"""Handle transcription message events."""
|
||||
await self._call_event_handler("on_transcription_message", message)
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@ class LiveKitCallbacks(BaseModel):
|
||||
|
||||
on_connected: Callable[[], Awaitable[None]]
|
||||
on_disconnected: Callable[[], Awaitable[None]]
|
||||
on_before_disconnect: Callable[[], Awaitable[None]]
|
||||
on_participant_connected: Callable[[str], Awaitable[None]]
|
||||
on_participant_disconnected: Callable[[str], Awaitable[None]]
|
||||
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
|
||||
@@ -282,6 +283,7 @@ class LiveKitTransportClient:
|
||||
return
|
||||
|
||||
logger.info(f"Disconnecting from {self._room_name}")
|
||||
await self._callbacks.on_before_disconnect()
|
||||
await self.room.disconnect()
|
||||
self._connected = False
|
||||
logger.info(f"Disconnected from {self._room_name}")
|
||||
@@ -918,6 +920,7 @@ class LiveKitTransport(BaseTransport):
|
||||
callbacks = LiveKitCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_disconnected=self._on_disconnected,
|
||||
on_before_disconnect=self._on_before_disconnect,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
on_audio_track_subscribed=self._on_audio_track_subscribed,
|
||||
@@ -947,6 +950,7 @@ class LiveKitTransport(BaseTransport):
|
||||
self._register_event_handler("on_first_participant_joined")
|
||||
self._register_event_handler("on_participant_left")
|
||||
self._register_event_handler("on_call_state_updated")
|
||||
self._register_event_handler("on_before_disconnect", sync=True)
|
||||
|
||||
def input(self) -> LiveKitInputTransport:
|
||||
"""Get the input transport for receiving media and events.
|
||||
@@ -1041,6 +1045,10 @@ class LiveKitTransport(BaseTransport):
|
||||
"""Handle room disconnected events."""
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _on_before_disconnect(self):
|
||||
"""Handle before disconnection room events."""
|
||||
await self._call_event_handler("on_before_disconnect")
|
||||
|
||||
async def _on_participant_connected(self, participant_id: str):
|
||||
"""Handle participant connected events."""
|
||||
await self._call_event_handler("on_participant_connected", participant_id)
|
||||
|
||||
@@ -95,15 +95,20 @@ class SmallWebRTCTrack:
|
||||
enable/disable control and frame discarding for audio and video streams.
|
||||
"""
|
||||
|
||||
def __init__(self, track: MediaStreamTrack):
|
||||
def __init__(self, receiver):
|
||||
"""Initialize the WebRTC track wrapper.
|
||||
|
||||
Args:
|
||||
track: The underlying MediaStreamTrack to wrap.
|
||||
index: The index of the track in the transceiver (0 for mic, 1 for cam, 2 for screen)
|
||||
receiver: The RemoteStreamTrack receiver instance.
|
||||
"""
|
||||
self._track = track
|
||||
self._receiver = receiver
|
||||
# Configuring the receiver for not consuming the track by default to prevent memory grow
|
||||
self._receiver._enabled = False
|
||||
self._track = receiver.track
|
||||
self._enabled = True
|
||||
self._last_recv_time: float = 0.0
|
||||
self._idle_task: Optional[asyncio.Task] = None
|
||||
self._idle_timeout: float = 2.0 # seconds before discarding old frames
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
"""Enable or disable the track.
|
||||
@@ -138,13 +143,44 @@ class SmallWebRTCTrack:
|
||||
async def recv(self) -> Optional[Frame]:
|
||||
"""Receive the next frame from the track.
|
||||
|
||||
Enables the internal receiving state and starts idle watcher.
|
||||
|
||||
Returns:
|
||||
The next frame, except for video tracks, where it returns the frame only if the track is enabled, otherwise, returns None.
|
||||
"""
|
||||
self._receiver._enabled = True
|
||||
self._last_recv_time = time.time()
|
||||
|
||||
# start idle watcher if not already running
|
||||
if not self._idle_task or self._idle_task.done():
|
||||
self._idle_task = asyncio.create_task(self._idle_watcher())
|
||||
|
||||
if not self._enabled and self._track.kind == "video":
|
||||
return None
|
||||
return await self._track.recv()
|
||||
|
||||
async def _idle_watcher(self):
|
||||
"""Disable receiving if idle for more than _idle_timeout and monitor queue size."""
|
||||
while self._receiver._enabled:
|
||||
await asyncio.sleep(self._idle_timeout)
|
||||
idle_duration = time.time() - self._last_recv_time
|
||||
if idle_duration >= self._idle_timeout:
|
||||
# discard old frames to prevent memory growth
|
||||
logger.debug(
|
||||
f"Disabling receiver for {self._track.kind} track after {idle_duration:.2f}s idle"
|
||||
)
|
||||
await self.discard_old_frames()
|
||||
self._receiver._enabled = False
|
||||
|
||||
def stop(self):
|
||||
"""Stop receiving frames from the track."""
|
||||
self._receiver._enabled = False
|
||||
if self._idle_task:
|
||||
self._idle_task.cancel()
|
||||
self._idle_task = None
|
||||
if self._track:
|
||||
self._track.stop()
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forward attribute access to the underlying track.
|
||||
|
||||
@@ -454,6 +490,10 @@ class SmallWebRTCConnection(BaseObject):
|
||||
|
||||
async def _close(self):
|
||||
"""Close the peer connection and cleanup resources."""
|
||||
for track in self._track_map.values():
|
||||
if track:
|
||||
track.stop()
|
||||
self._track_map.clear()
|
||||
if self._pc:
|
||||
await self._pc.close()
|
||||
self._message_queue.clear()
|
||||
@@ -526,8 +566,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No audio transceiver is available")
|
||||
return None
|
||||
|
||||
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
|
||||
audio_track = SmallWebRTCTrack(track) if track else None
|
||||
receiver = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver
|
||||
audio_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
|
||||
return audio_track
|
||||
|
||||
@@ -548,8 +588,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No video transceiver is available")
|
||||
return None
|
||||
|
||||
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
|
||||
video_track = SmallWebRTCTrack(track) if track else None
|
||||
receiver = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver
|
||||
video_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
|
||||
return video_track
|
||||
|
||||
@@ -570,8 +610,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No screen video transceiver is available")
|
||||
return None
|
||||
|
||||
track = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver.track
|
||||
video_track = SmallWebRTCTrack(track) if track else None
|
||||
receiver = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver
|
||||
video_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
self._track_map[SCREEN_VIDEO_TRANSCEIVER_INDEX] = video_track
|
||||
return video_track
|
||||
|
||||
|
||||
200
src/pipecat/transports/smallwebrtc/request_handler.py
Normal file
200
src/pipecat/transports/smallwebrtc/request_handler.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""SmallWebRTC request handler for managing peer connections.
|
||||
|
||||
This module provides a client for handling web requests and managing WebRTC connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmallWebRTCRequest:
|
||||
"""Small WebRTC transport session arguments for the runner.
|
||||
|
||||
Parameters:
|
||||
sdp: The SDP string (Session Description Protocol).
|
||||
type: The type of the SDP, either "offer" or "answer".
|
||||
pc_id: Optional identifier for the peer connection.
|
||||
restart_pc: Optional whether to restart the peer connection.
|
||||
request_data: Optional custom data sent by the customer.
|
||||
"""
|
||||
|
||||
sdp: str
|
||||
type: str
|
||||
pc_id: Optional[str] = None
|
||||
restart_pc: Optional[bool] = None
|
||||
request_data: Optional[Any] = None
|
||||
|
||||
|
||||
class ConnectionMode(Enum):
|
||||
"""Enum defining the connection handling modes."""
|
||||
|
||||
SINGLE = "single" # Only one active connection allowed
|
||||
MULTIPLE = "multiple" # Multiple simultaneous connections allowed
|
||||
|
||||
|
||||
class SmallWebRTCRequestHandler:
|
||||
"""SmallWebRTC request handler for managing peer connections.
|
||||
|
||||
This class is responsible for:
|
||||
- Handling incoming SmallWebRTC requests.
|
||||
- Creating and managing WebRTC peer connections.
|
||||
- Supporting ESP32-specific SDP munging if enabled.
|
||||
- Invoking callbacks for newly initialized connections.
|
||||
- Supporting both single and multiple connection modes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ice_servers: Optional[List[IceServer]] = None,
|
||||
esp32_mode: bool = False,
|
||||
host: Optional[str] = None,
|
||||
connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
|
||||
) -> None:
|
||||
"""Initialize a SmallWebRTC request handler.
|
||||
|
||||
Args:
|
||||
ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
|
||||
connections.
|
||||
esp32_mode (bool): If True, enables ESP32-specific SDP munging.
|
||||
host (Optional[str]): Host address used for SDP munging in ESP32 mode.
|
||||
Ignored if `esp32_mode` is False.
|
||||
connection_mode (ConnectionMode): Mode of operation for handling connections.
|
||||
SINGLE allows only one active connection, MULTIPLE allows several.
|
||||
"""
|
||||
self._ice_servers = ice_servers
|
||||
self._esp32_mode = esp32_mode
|
||||
self._host = host
|
||||
self._connection_mode = connection_mode
|
||||
|
||||
# Store connections by pc_id
|
||||
self._pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
def _check_single_connection_constraints(self, pc_id: Optional[str]) -> None:
|
||||
"""Check if the connection request satisfies single connection mode constraints.
|
||||
|
||||
Args:
|
||||
pc_id: The peer connection ID from the request
|
||||
|
||||
Raises:
|
||||
HTTPException: If constraints are violated in single connection mode
|
||||
"""
|
||||
if self._connection_mode != ConnectionMode.SINGLE:
|
||||
return
|
||||
|
||||
if not self._pcs_map: # No existing connections
|
||||
return
|
||||
|
||||
# Get the existing connection (should be only one in single mode)
|
||||
existing_connection = next(iter(self._pcs_map.values()))
|
||||
|
||||
if existing_connection.pc_id != pc_id and pc_id:
|
||||
logger.warning(
|
||||
f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")
|
||||
|
||||
if not pc_id:
|
||||
logger.warning(
|
||||
"Cannot create new connection: existing connection found but no pc_id received"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot create new connection with existing connection active",
|
||||
)
|
||||
|
||||
async def handle_web_request(
|
||||
self,
|
||||
request: SmallWebRTCRequest,
|
||||
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Handle a SmallWebRTC request and resolve the pending answer.
|
||||
|
||||
This method will:
|
||||
- Reuse an existing WebRTC connection if `pc_id` exists.
|
||||
- Otherwise, create a new `SmallWebRTCConnection`.
|
||||
- Invoke the provided callback with the connection.
|
||||
- Manage ESP32-specific munging if enabled.
|
||||
- Enforce single/multiple connection mode constraints.
|
||||
|
||||
Args:
|
||||
request (SmallWebRTCRequest): The incoming WebRTC request, containing
|
||||
SDP, type, and optionally a `pc_id`.
|
||||
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
|
||||
asynchronous callback function that is invoked with the WebRTC connection.
|
||||
|
||||
Raises:
|
||||
HTTPException: If connection mode constraints are violated
|
||||
Exception: Any exception raised during request handling or callback execution
|
||||
will be logged and propagated.
|
||||
"""
|
||||
try:
|
||||
pc_id = request.pc_id
|
||||
|
||||
# Check connection mode constraints first
|
||||
self._check_single_connection_constraints(pc_id)
|
||||
|
||||
# After constraints are satisfied, get the existing connection if any
|
||||
existing_connection = self._pcs_map.get(pc_id) if pc_id else None
|
||||
|
||||
if existing_connection:
|
||||
pipecat_connection = existing_connection
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(
|
||||
sdp=request.sdp,
|
||||
type=request.type,
|
||||
restart_pc=request.restart_pc or False,
|
||||
)
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
|
||||
await pipecat_connection.initialize(sdp=request.sdp, type=request.type)
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
|
||||
self._pcs_map.pop(webrtc_connection.pc_id, None)
|
||||
|
||||
# Invoke callback provided in runner arguments
|
||||
try:
|
||||
await webrtc_connection_callback(pipecat_connection)
|
||||
logger.debug(
|
||||
f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
|
||||
)
|
||||
except Exception as callback_error:
|
||||
logger.error(
|
||||
f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
|
||||
)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
|
||||
if self._esp32_mode and self._host and self._host != "localhost":
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)
|
||||
|
||||
self._pcs_map[answer["pc_id"]] = pipecat_connection
|
||||
|
||||
return answer
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SmallWebRTC request: {e}")
|
||||
logger.debug(f"SmallWebRTC request details: {request}")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Clear the connection map."""
|
||||
coros = [pc.disconnect() for pc in self._pcs_map.values()]
|
||||
await asyncio.gather(*coros)
|
||||
self._pcs_map.clear()
|
||||
@@ -478,7 +478,11 @@ class SmallWebRTCClient:
|
||||
self._screen_video_track = None
|
||||
self._audio_output_track = None
|
||||
self._video_output_track = None
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._closing:
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
async def _handle_app_message(self, message: Any):
|
||||
"""Handle incoming application messages."""
|
||||
|
||||
@@ -138,7 +138,6 @@ class FastAPIWebsocketClient:
|
||||
):
|
||||
logger.warning("Closing already disconnected websocket!")
|
||||
self._closing = True
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect the WebSocket client."""
|
||||
@@ -152,8 +151,6 @@ class FastAPIWebsocketClient:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception while closing the websocket: {e}")
|
||||
finally:
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def trigger_client_disconnected(self):
|
||||
"""Trigger the client disconnected callback."""
|
||||
@@ -298,7 +295,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._client.trigger_client_disconnected()
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._client.is_closing:
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
@@ -446,6 +446,9 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
"""Serialize and send a frame through the WebSocket."""
|
||||
if self._client.is_closing or not self._client.is_connected:
|
||||
return
|
||||
|
||||
if not self._params.serializer:
|
||||
return
|
||||
|
||||
|
||||
@@ -14,13 +14,33 @@ and async cleanup for all Pipecat components.
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandler:
|
||||
"""Data class to store event handlers information.
|
||||
|
||||
This data class stores the event name, a list of handlers to run for this
|
||||
event, and whether these handlers will be executed in a task.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the event handler.
|
||||
handlers (List[Any]): A list of functions to be called when this event is triggered.
|
||||
is_sync (bool): Indicates whether the functions are executed in a task.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
handlers: List[Any]
|
||||
is_sync: bool
|
||||
|
||||
|
||||
class BaseObject(ABC):
|
||||
"""Abstract base class providing common functionality for Pipecat objects.
|
||||
|
||||
@@ -41,7 +61,7 @@ class BaseObject(ABC):
|
||||
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
|
||||
# Registered event handlers.
|
||||
self._event_handlers: dict = {}
|
||||
self._event_handlers: Dict[str, EventHandler] = {}
|
||||
|
||||
# Set of tasks being executed. When a task finishes running it gets
|
||||
# automatically removed from the set. When we cleanup we wait for all
|
||||
@@ -103,18 +123,21 @@ class BaseObject(ABC):
|
||||
Can be sync or async.
|
||||
"""
|
||||
if event_name in self._event_handlers:
|
||||
self._event_handlers[event_name].append(handler)
|
||||
self._event_handlers[event_name].handlers.append(handler)
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} not registered")
|
||||
|
||||
def _register_event_handler(self, event_name: str):
|
||||
def _register_event_handler(self, event_name: str, sync: bool = False):
|
||||
"""Register an event handler type.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event type to register.
|
||||
sync: Whether this event handler will be executed in a task.
|
||||
"""
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = []
|
||||
self._event_handlers[event_name] = EventHandler(
|
||||
name=event_name, handlers=[], is_sync=sync
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} not registered")
|
||||
|
||||
@@ -126,34 +149,43 @@ class BaseObject(ABC):
|
||||
*args: Positional arguments to pass to event handlers.
|
||||
**kwargs: Keyword arguments to pass to event handlers.
|
||||
"""
|
||||
# If we haven't registered an event handler, we don't need to do
|
||||
# anything.
|
||||
if not self._event_handlers.get(event_name):
|
||||
if event_name not in self._event_handlers:
|
||||
return
|
||||
|
||||
# Create the task.
|
||||
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
|
||||
event_handler = self._event_handlers[event_name]
|
||||
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
for handler in event_handler.handlers:
|
||||
if event_handler.is_sync:
|
||||
# Just run the handler.
|
||||
await self._run_handler(event_handler.name, handler, *args, **kwargs)
|
||||
else:
|
||||
# Create the task. Note that this is a task per each function
|
||||
# handler. Users can register to an event handler multiple
|
||||
# times.
|
||||
task = asyncio.create_task(
|
||||
self._run_handler(event_handler.name, handler, *args, **kwargs)
|
||||
)
|
||||
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
|
||||
async def _run_task(self, event_name: str, *args, **kwargs):
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
|
||||
async def _run_handler(self, event_name: str, handler, *args, **kwargs):
|
||||
"""Execute all handlers for an event.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event being handled.
|
||||
event_name: The event name for this handler.
|
||||
handler: The handler function to run.
|
||||
*args: Positional arguments to pass to handlers.
|
||||
**kwargs: Keyword arguments to pass to handlers.
|
||||
"""
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(self, *args, **kwargs)
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(self, *args, **kwargs)
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||
|
||||
|
||||
67
tests/test_frame_processor.py
Normal file
67
tests/test_frame_processor.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
TextFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_interruption_and_wait(self):
|
||||
class DelayFrameProcessor(FrameProcessor):
|
||||
"""This processors just gives time to the event loop to change
|
||||
between tasks. Otherwise things happen to fast."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await asyncio.sleep(0.1)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class InterruptFrameProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(TransportMessageUrgentFrame(message=frame.text))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
# Just a random interruption to make sure we don't clear anything
|
||||
# before the actual `InterruptionTaskFrame` interruption.
|
||||
InterruptionFrame(),
|
||||
# This will generate an `InterruptionTaskFrame` and will wait for an
|
||||
# `InterruptionFrame`.
|
||||
TextFrame(text="Hello from Pipecat!"),
|
||||
# Just give time for everything to complete.
|
||||
SleepFrame(sleep=0.5),
|
||||
EndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
InterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
EndFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
)
|
||||
998
tests/test_get_llm_invocation_params.py
Normal file
998
tests/test_get_llm_invocation_params.py
Normal file
@@ -0,0 +1,998 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for LLM adapters' get_llm_invocation_params() method.
|
||||
|
||||
These tests focus specifically on the "messages" field generation for different adapters, ensuring:
|
||||
|
||||
For OpenAI adapter:
|
||||
1. LLMStandardMessage objects are passed through unchanged
|
||||
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
|
||||
3. Complex message structures (like multi-part content) are preserved
|
||||
4. System instructions are preserved throughout messages at any position
|
||||
|
||||
For Gemini adapter:
|
||||
1. LLMStandardMessage objects are converted to Gemini Content format
|
||||
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
|
||||
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
|
||||
4. System messages are extracted as system_instruction (without duplication)
|
||||
5. Single system instruction is converted to user message when no other messages exist
|
||||
6. Multiple system instructions: first extracted, later ones converted to user messages
|
||||
|
||||
For Anthropic adapter:
|
||||
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
|
||||
For AWS Bedrock adapter:
|
||||
1. LLMStandardMessage objects are converted to AWS Bedrock format
|
||||
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from google.genai.types import Content, Part
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = OpenAILLMAdapter()
|
||||
|
||||
def test_standard_messages_passed_through_unchanged(self):
|
||||
"""Test that LLMStandardMessage objects are passed through unchanged to OpenAI params."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify messages are passed through unchanged
|
||||
self.assertEqual(params["messages"], standard_messages)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify content matches exactly
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
|
||||
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that OpenAI-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Gemini specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
{"role": "assistant", "content": "OpenAI specific response"}
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and OpenAI-specific ones
|
||||
# (3 total: system, standard user, openai assistant)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Standard user message")
|
||||
self.assertEqual(
|
||||
params["messages"][2], {"role": "assistant", "content": "OpenAI specific response"}
|
||||
)
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved."""
|
||||
# Create a message with complex content structure (text + image)
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..."},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant that can analyze images."},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex content is preserved
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][1], complex_image_message)
|
||||
self.assertEqual(params["messages"][2], multi_text_message)
|
||||
|
||||
# Verify the image message structure is maintained
|
||||
image_content = params["messages"][1]["content"]
|
||||
self.assertIsInstance(image_content, list)
|
||||
self.assertEqual(len(image_content), 2)
|
||||
self.assertEqual(image_content[0]["type"], "text")
|
||||
self.assertEqual(image_content[1]["type"], "image_url")
|
||||
|
||||
# Verify the multi-text message structure is maintained
|
||||
text_content = params["messages"][2]["content"]
|
||||
self.assertIsInstance(text_content, list)
|
||||
self.assertEqual(len(text_content), 3)
|
||||
for i, text_block in enumerate(text_content):
|
||||
self.assertEqual(text_block["type"], "text")
|
||||
self.assertEqual(text_content[0]["text"], "Let me analyze this step by step:")
|
||||
self.assertEqual(text_content[1]["text"], "1. First, I'll examine the visual elements")
|
||||
self.assertEqual(text_content[2]["text"], "2. Then I'll provide my conclusions")
|
||||
|
||||
def test_system_instructions_preserved_throughout_messages(self):
|
||||
"""Test that OpenAI adapter preserves system instructions sprinkled throughout messages."""
|
||||
# Create messages with system instructions at different positions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# OpenAI should preserve all messages unchanged, including multiple system messages
|
||||
self.assertEqual(len(params["messages"]), 7)
|
||||
|
||||
# Verify system messages are preserved at their original positions
|
||||
self.assertEqual(params["messages"][0]["role"], "system")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
self.assertEqual(params["messages"][3]["role"], "system")
|
||||
self.assertEqual(params["messages"][3]["content"], "Remember to be concise.")
|
||||
|
||||
self.assertEqual(params["messages"][5]["role"], "system")
|
||||
self.assertEqual(params["messages"][5]["content"], "Use simple language.")
|
||||
|
||||
# Verify other messages remain unchanged
|
||||
self.assertEqual(params["messages"][1]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][4]["role"], "user")
|
||||
self.assertEqual(params["messages"][6]["role"], "assistant")
|
||||
|
||||
|
||||
class TestGeminiGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = GeminiLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_gemini_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Gemini Content format."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are converted to Gemini format (2 messages: user + model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg, Content)
|
||||
self.assertEqual(user_msg.role, "user")
|
||||
self.assertEqual(len(user_msg.parts), 1)
|
||||
self.assertEqual(user_msg.parts[0].text, "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant -> model)
|
||||
model_msg = params["messages"][1]
|
||||
self.assertIsInstance(model_msg, Content)
|
||||
self.assertEqual(model_msg.role, "model")
|
||||
self.assertEqual(len(model_msg.parts), 1)
|
||||
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Gemini-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific message"}
|
||||
),
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
Content(role="model", parts=[Part(text="Gemini specific response")]),
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and Gemini-specific ones
|
||||
# (2 total: converted standard user + gemini model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "Standard user message")
|
||||
|
||||
self.assertEqual(params["messages"][1].role, "model")
|
||||
self.assertEqual(params["messages"][1].parts[0].text, "Gemini specific response")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved and converted.
|
||||
|
||||
This test covers image, audio, and multi-text content conversion to Gemini format.
|
||||
"""
|
||||
# Create a message with complex content structure (text + image)
|
||||
# Using a minimal valid base64 image data
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with audio input (text + audio)
|
||||
# Using a minimal valid base64 audio data (16 bytes of WAV header)
|
||||
audio_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you transcribe this audio?"},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=",
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can analyze images and audio.",
|
||||
},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
audio_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(
|
||||
params["system_instruction"],
|
||||
"You are a helpful assistant that can analyze images and audio.",
|
||||
)
|
||||
|
||||
# Verify complex content is converted to Gemini format
|
||||
# Note: Gemini adapter may add system instruction back as user message in some cases
|
||||
self.assertGreaterEqual(len(params["messages"]), 3)
|
||||
|
||||
# Find the different message types
|
||||
user_with_image = None
|
||||
model_with_text = None
|
||||
user_with_audio = None
|
||||
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and len(msg.parts) == 2:
|
||||
# Check if it's image or audio based on the text content
|
||||
if hasattr(msg.parts[0], "text") and "image" in msg.parts[0].text:
|
||||
user_with_image = msg
|
||||
elif hasattr(msg.parts[0], "text") and "audio" in msg.parts[0].text:
|
||||
user_with_audio = msg
|
||||
elif msg.role == "model" and len(msg.parts) == 3:
|
||||
model_with_text = msg
|
||||
|
||||
# Verify the image message structure is converted properly
|
||||
self.assertIsNotNone(user_with_image, "Should have user message with image")
|
||||
self.assertEqual(len(user_with_image.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_image.parts[0].text, "What's in this image?")
|
||||
|
||||
# Second part should be image data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_image.parts[1].inline_data)
|
||||
self.assertEqual(user_with_image.parts[1].inline_data.mime_type, "image/jpeg")
|
||||
|
||||
# Verify the audio message structure is converted properly
|
||||
self.assertIsNotNone(user_with_audio, "Should have user message with audio")
|
||||
self.assertEqual(len(user_with_audio.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_audio.parts[0].text, "Can you transcribe this audio?")
|
||||
|
||||
# Second part should be audio data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_audio.parts[1].inline_data)
|
||||
self.assertEqual(user_with_audio.parts[1].inline_data.mime_type, "audio/wav")
|
||||
|
||||
# Verify the multi-text message structure is converted properly
|
||||
self.assertIsNotNone(model_with_text, "Should have model message with multi-text")
|
||||
self.assertEqual(len(model_with_text.parts), 3)
|
||||
|
||||
# All parts should be text
|
||||
expected_texts = [
|
||||
"Let me analyze this step by step:",
|
||||
"1. First, I'll examine the visual elements",
|
||||
"2. Then I'll provide my conclusions",
|
||||
]
|
||||
for i, expected_text in enumerate(expected_texts):
|
||||
self.assertEqual(model_with_text.parts[i].text, expected_text)
|
||||
|
||||
def test_single_system_instruction_converted_to_user(self):
|
||||
"""Test that when there's only a system instruction, it gets converted to user message."""
|
||||
# Create context with only a system message
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# But since there are no other messages, it should also be added back as a user message
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "You are a helpful assistant.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
# Create messages with multiple system instructions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# First system instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 6 messages (original 7 minus 1 system instruction that was extracted)
|
||||
self.assertEqual(len(params["messages"]), 6)
|
||||
|
||||
# Find the converted system messages (should be user role now)
|
||||
converted_system_messages = []
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and (
|
||||
msg.parts[0].text == "Remember to be concise."
|
||||
or msg.parts[0].text == "Use simple language."
|
||||
):
|
||||
converted_system_messages.append(msg.parts[0].text)
|
||||
|
||||
# Should have 2 converted system messages
|
||||
self.assertEqual(len(converted_system_messages), 2)
|
||||
self.assertIn("Remember to be concise.", converted_system_messages)
|
||||
self.assertIn("Use simple language.", converted_system_messages)
|
||||
|
||||
# Verify that regular user and assistant messages are preserved
|
||||
user_messages = [msg for msg in params["messages"] if msg.role == "user"]
|
||||
model_messages = [msg for msg in params["messages"] if msg.role == "model"]
|
||||
|
||||
# Should have 4 user messages: 2 original + 2 converted from system
|
||||
self.assertEqual(len(user_messages), 4)
|
||||
# Should have 2 model messages (converted from assistant)
|
||||
self.assertEqual(len(model_messages), 2)
|
||||
|
||||
|
||||
class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AnthropicLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_anthropic_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Anthropic MessageParam format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertEqual(user_msg["content"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant)
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Anthropic-specific messages are included and others are filtered out."""
|
||||
# Create anthropic-specific message content
|
||||
anthropic_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "fake_data"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(anthropic_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + anthropic-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + anthropic-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + anthropic text + anthropic image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "image")
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["content"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to Anthropic format."""
|
||||
# Create a complex message with both text and image content
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,fake_image_data"},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: Anthropic adapter reorders single images to come before text, as per Anthropic docs
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "image")
|
||||
self.assertIn("source", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["source"]["type"], "base64")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["media_type"], "image/jpeg")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["data"], "fake_image_data")
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System instruction should be extracted from first message
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_converted_to_user(self):
|
||||
"""Test that a single system message is converted to user role when no other messages exist."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System should be NOT_GIVEN since we only have one message
|
||||
from anthropic import NOT_GIVEN
|
||||
|
||||
self.assertEqual(params["system"], NOT_GIVEN)
|
||||
|
||||
# Single system message should be converted to user role
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
|
||||
class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AWSBedrockLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_aws_bedrock_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to AWS Bedrock format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user) - should be converted to AWS Bedrock format
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 1)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant) - should be converted to AWS Bedrock format
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 1)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that AWS-specific messages are included and others are filtered out."""
|
||||
# Create aws-specific message content (which is what AWS Bedrock uses)
|
||||
aws_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello"},
|
||||
{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(message=aws_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + aws-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + aws-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + aws text + aws image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertIn("image", user_msg["content"][2])
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to AWS Bedrock format."""
|
||||
# Create a complex message with both text and image content
|
||||
# Use a valid base64 string for the image
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: AWS Bedrock adapter reorders single images to come before text, like Anthropic
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertIn("image", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["image"]["format"], "jpeg")
|
||||
self.assertIn("source", user_msg["content"][0]["image"])
|
||||
self.assertIn("bytes", user_msg["content"][0]["image"]["source"])
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted from first message (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"][0]["text"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_handling(self):
|
||||
"""Test that a single system message is extracted as system parameter and no messages remain."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System should be extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# No messages should remain after system extraction
|
||||
self.assertEqual(len(params["messages"]), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -196,10 +196,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal start_received
|
||||
start_received = True
|
||||
|
||||
@task.event_handler("on_pipeline_ended")
|
||||
async def on_pipeline_ended(task, frame: EndFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal end_received
|
||||
end_received = True
|
||||
end_received = isinstance(frame, EndFrame)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -214,10 +214,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@task.event_handler("on_pipeline_stopped")
|
||||
async def on_pipeline_ended(task, frame: StopFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal stop_received
|
||||
stop_received = True
|
||||
stop_received = isinstance(frame, StopFrame)
|
||||
|
||||
await task.queue_frame(StopFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -441,10 +441,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def on_pipeline_started(task: PipelineTask, frame: StartFrame):
|
||||
await task.cancel()
|
||||
|
||||
@task.event_handler("on_pipeline_cancelled")
|
||||
async def on_pipeline_cancelled(task: PipelineTask, frame: CancelFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
cancelled = isinstance(frame, CancelFrame)
|
||||
|
||||
try:
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
261
tests/test_run_inference.py
Normal file
261
tests/test_run_inference.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from anthropic import NOT_GIVEN
|
||||
from openai import NotGiven
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response."""
|
||||
# Create service with mocked client
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
]
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hello! How can I help you today?"
|
||||
service._client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.chat.completions.create.assert_called_once_with(
|
||||
model="gpt-4",
|
||||
messages=test_messages,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_client_exception():
|
||||
"""Test that exceptions from the client are propagated."""
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.chat.completions.create.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Anthropic."""
|
||||
# Create service with mocked client
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Hello! How can I help you today?"
|
||||
service._client.messages.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(
|
||||
mock_context, enable_prompt_caching=False
|
||||
)
|
||||
service._client.messages.create.assert_called_once_with(
|
||||
model="claude-3-sonnet-20240229",
|
||||
messages=test_messages,
|
||||
system=test_system,
|
||||
max_tokens=8192,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_client_exception():
|
||||
"""Test that exceptions from the Anthropic client are propagated."""
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=[], system="Test system", tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.messages.create.side_effect = Exception("Anthropic API Error")
|
||||
|
||||
with pytest.raises(Exception, match="Anthropic API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Google."""
|
||||
# Create service with mocked client
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=test_messages, system_instruction=test_system, tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [MagicMock()]
|
||||
mock_response.candidates[0].content = MagicMock()
|
||||
mock_response.candidates[0].content.parts = [MagicMock()]
|
||||
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_client_exception():
|
||||
"""Test that exceptions from the Google client are propagated."""
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=[], system_instruction="Test system", tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(
|
||||
side_effect=Exception("Google API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Google API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
|
||||
# Create service and patch the session client method
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
|
||||
test_system = [{"text": "You are a helpful assistant"}]
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock the client and response
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
|
||||
}
|
||||
mock_client.converse.return_value = mock_response
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
async def mock_client_cm(*args, **kwargs):
|
||||
return mock_client
|
||||
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
mock_client.converse.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_client_exception():
|
||||
"""Test that exceptions from the AWS Bedrock client are propagated."""
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock AWS client to raise exception
|
||||
mock_client = AsyncMock()
|
||||
mock_client.converse.side_effect = Exception("Bedrock API Error")
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
with pytest.raises(Exception, match="Bedrock API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
303
tests/test_service_switcher.py
Normal file
303
tests/test_service_switcher.py
Normal file
@@ -0,0 +1,303 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Unit tests for ServiceSwitcher and related components."""
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
class MockFrameProcessor(FrameProcessor):
|
||||
"""A test frame processor that tracks which frames it has processed."""
|
||||
|
||||
def __init__(self, test_name: str, **kwargs):
|
||||
"""Initialize the test processor with a name.
|
||||
|
||||
Args:
|
||||
test_name: A unique name for this processor instance.
|
||||
**kwargs: Additional arguments passed to the parent FrameProcessor.
|
||||
"""
|
||||
super().__init__(name=test_name, **kwargs)
|
||||
self.test_name = test_name
|
||||
self.processed_frames = []
|
||||
self.frame_count = 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process an incoming frame and track it.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
self.processed_frames.append(frame)
|
||||
self.frame_count += 1
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def reset_counters(self):
|
||||
"""Reset the frame tracking counters."""
|
||||
self.processed_frames = []
|
||||
self.frame_count = 0
|
||||
|
||||
|
||||
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcherStrategyManual."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_with_services(self):
|
||||
"""Test initialization with a list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
self.assertEqual(strategy.services, self.services)
|
||||
self.assertEqual(strategy.active_service, self.service1) # First service should be active
|
||||
|
||||
def test_init_with_empty_services(self):
|
||||
"""Test initialization with an empty list of services."""
|
||||
strategy = ServiceSwitcherStrategyManual([])
|
||||
|
||||
self.assertEqual(strategy.services, [])
|
||||
self.assertIsNone(strategy.active_service)
|
||||
|
||||
def test_handle_manually_switch_service_frame(self):
|
||||
"""Test manual service switching with ManuallySwitchServiceFrame."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
|
||||
# Initially service1 should be active
|
||||
self.assertEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
|
||||
# Switch to service2
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertEqual(strategy.active_service, self.service2)
|
||||
self.assertNotEqual(strategy.active_service, self.service3)
|
||||
|
||||
# Switch to service3
|
||||
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
|
||||
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertNotEqual(strategy.active_service, self.service1)
|
||||
self.assertNotEqual(strategy.active_service, self.service2)
|
||||
self.assertEqual(strategy.active_service, self.service3)
|
||||
|
||||
def test_handle_frame_unsupported_frame_type(self):
|
||||
"""Test that unsupported frame types raise an error."""
|
||||
strategy = ServiceSwitcherStrategyManual(self.services)
|
||||
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
self.assertIn("Unsupported frame type", str(context.exception))
|
||||
|
||||
|
||||
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceSwitcher."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.service1 = MockFrameProcessor("service1")
|
||||
self.service2 = MockFrameProcessor("service2")
|
||||
self.service3 = MockFrameProcessor("service3")
|
||||
self.services = [self.service1, self.service2, self.service3]
|
||||
|
||||
def test_init_with_manual_strategy(self):
|
||||
"""Test initialization with manual strategy."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
self.assertEqual(switcher.services, self.services)
|
||||
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
|
||||
self.assertEqual(switcher.strategy.services, self.services)
|
||||
|
||||
async def test_default_active_service(self):
|
||||
"""Test that the initially-active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
service.reset_counters()
|
||||
|
||||
# Send some test frames
|
||||
frames_to_send = [
|
||||
TextFrame(text="Hello 1"),
|
||||
TextFrame(text="Hello 2"),
|
||||
TextFrame(text="Hello 3"),
|
||||
]
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[TextFrame, TextFrame, TextFrame],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
# Only service1 should have processed the text frames
|
||||
# Note: The service also receives StartFrame and EndFrame, so count those too
|
||||
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
|
||||
self.assertEqual(len(text_frames), 3)
|
||||
|
||||
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service3_text_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
self.assertEqual(len(service2_text_frames), 0)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
|
||||
# Verify the actual text frames processed
|
||||
for i, frame in enumerate(text_frames):
|
||||
self.assertEqual(frame.text, f"Hello {i + 1}")
|
||||
|
||||
async def test_service_switching(self):
|
||||
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Reset counters
|
||||
for service in self.services:
|
||||
service.reset_counters()
|
||||
|
||||
# Send a test frame, a switch frame, and another test frame
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=[
|
||||
TextFrame("Hello 1"),
|
||||
ManuallySwitchServiceFrame(service=self.service2),
|
||||
TextFrame("Hello 2"),
|
||||
],
|
||||
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
# Verify service2 received the frame
|
||||
service1_text_frames = [
|
||||
f for f in self.service1.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service2_text_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
service3_text_frames = [
|
||||
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
|
||||
self.assertEqual(len(service1_text_frames), 1)
|
||||
self.assertEqual(len(service2_text_frames), 1)
|
||||
self.assertEqual(len(service3_text_frames), 0)
|
||||
|
||||
self.assertEqual(service1_text_frames[0].text, "Hello 1")
|
||||
self.assertEqual(service2_text_frames[0].text, "Hello 2")
|
||||
|
||||
async def test_multi_service_switcher_targeting(self):
|
||||
"""Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline."""
|
||||
# Create services for first switcher
|
||||
switcher1_service1 = MockFrameProcessor("switcher1_service1")
|
||||
switcher1_service2 = MockFrameProcessor("switcher1_service2")
|
||||
switcher1_services = [switcher1_service1, switcher1_service2]
|
||||
|
||||
# Create services for second switcher
|
||||
switcher2_service1 = MockFrameProcessor("switcher2_service1")
|
||||
switcher2_service2 = MockFrameProcessor("switcher2_service2")
|
||||
switcher2_services = [switcher2_service1, switcher2_service2]
|
||||
|
||||
# Create two service switchers
|
||||
switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
|
||||
switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Create a pipeline with both switchers: switcher1 -> switcher2
|
||||
pipeline = Pipeline([switcher1, switcher2])
|
||||
|
||||
# Reset counters
|
||||
for service in switcher1_services + switcher2_services:
|
||||
service.reset_counters()
|
||||
|
||||
# Initially, both switchers should use their first services
|
||||
self.assertEqual(switcher1.strategy.active_service, switcher1_service1)
|
||||
self.assertEqual(switcher2.strategy.active_service, switcher2_service1)
|
||||
|
||||
# Send frames to test the pipeline:
|
||||
# 1. Text frame (should go through both switchers' active services)
|
||||
# 2. Switch frame targeting switcher1's second service
|
||||
# 3. Text frame (should go through switcher1's new service and switcher2's original service)
|
||||
# 4. Switch frame targeting switcher2's second service
|
||||
# 5. Text frame (should go through switcher1's current service and switcher2's new service)
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=[
|
||||
TextFrame("Before any switches"),
|
||||
ManuallySwitchServiceFrame(service=switcher1_service2), # Switch first switcher
|
||||
TextFrame("After switching first switcher"),
|
||||
ManuallySwitchServiceFrame(service=switcher2_service2), # Switch second switcher
|
||||
TextFrame("After switching second switcher"),
|
||||
],
|
||||
expected_down_frames=[
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
],
|
||||
expected_up_frames=[], # Expect no error frames
|
||||
)
|
||||
|
||||
# Verify the active services changed correctly
|
||||
self.assertEqual(switcher1.strategy.active_service, switcher1_service2)
|
||||
self.assertEqual(switcher2.strategy.active_service, switcher2_service2)
|
||||
|
||||
# Verify frame distribution:
|
||||
# First text frame should go through switcher1_service1 and switcher2_service1
|
||||
switcher1_service1_texts = [
|
||||
f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
switcher2_service1_texts = [
|
||||
f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
|
||||
# Second text frame should go through switcher1_service2 and switcher2_service1
|
||||
switcher1_service2_texts = [
|
||||
f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
|
||||
# Third text frame should go through switcher1_service2 and switcher2_service2
|
||||
switcher2_service2_texts = [
|
||||
f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame)
|
||||
]
|
||||
|
||||
# Verify frame counts and content
|
||||
self.assertEqual(len(switcher1_service1_texts), 1)
|
||||
self.assertEqual(switcher1_service1_texts[0].text, "Before any switches")
|
||||
|
||||
self.assertEqual(len(switcher1_service2_texts), 2)
|
||||
self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher")
|
||||
self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher")
|
||||
|
||||
self.assertEqual(len(switcher2_service1_texts), 2)
|
||||
self.assertEqual(switcher2_service1_texts[0].text, "Before any switches")
|
||||
self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher")
|
||||
|
||||
self.assertEqual(len(switcher2_service2_texts), 1)
|
||||
self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user