Compare commits

...

66 Commits

Author SHA1 Message Date
Chad Bailey
05863ded53 add remote participant updates to DailyTransport 2025-09-19 21:15:07 +00:00
chadbailey59
6ab4a48d8f Add remote participant updates to DailyTransport (#2694)
* add remote participant updates to DailyTransport

* cleanup

* cleanup

* ruff cleanup again
2025-09-19 21:15:07 +00:00
Aleix Conchillo Flaqué
89e0092159 BaseObject: run each handler for the same event in a separate task 2025-09-19 21:15:07 +00:00
Aleix Conchillo Flaqué
0de31dab79 LiveKitTransport: added synchronous before_disconnect event 2025-09-19 21:15:07 +00:00
Aleix Conchillo Flaqué
10ff93307d DailyTransport: added synchronous on_before_disconnect event 2025-09-19 21:15:07 +00:00
Aleix Conchillo Flaqué
414c9e3bc8 BaseObject: allow synchronous event handlers 2025-09-19 21:15:07 +00:00
Chad Bailey
ba64f126a3 reset paused frame processors after interruption 2025-09-19 18:15:01 +00:00
Chad Bailey
d28e3881a7 add remote participant updates to DailyTransport 2025-09-19 18:11:06 +00:00
Mark Backman
7df7395dd1 Merge pull request #2692 from pipecat-ai/mb/lazy-load-smallwebrtc-request
Lazy load SmallWebRTCRequest, SmallWebRTCRequestHandler in runner
2025-09-19 10:43:43 -07:00
Mark Backman
0885bc9cdf Lazy load SmallWebRTCRequest, SmallWebRTCRequestHandler in runner 2025-09-19 13:28:01 -04:00
Aleix Conchillo Flaqué
0204f6a95d Merge pull request #2686 from pipecat-ai/aleix/silero-vad-v6
audio(vad): update Silero VAD model to v6
2025-09-18 20:31:10 -07:00
Mark Backman
b0bf653f04 Merge pull request #2679 from pipecat-ai/mb/gladia-remove-confidence
GladiaSTTService: deprecate confidence arg
2025-09-18 17:41:33 -07:00
Mark Backman
e8a676eb36 GladiaSTTService: deprecate confidence arg 2025-09-18 20:38:53 -04:00
Mark Backman
ca96eef1f3 Merge pull request #2680 from pipecat-ai/mb/dial-in-session-id
DailyTransport sip_call_transfer now automatically receives session_id
2025-09-18 17:36:51 -07:00
Mark Backman
8e1637d6c7 DailyTransport sip_call_transfer now automatically receives session_id 2025-09-18 20:34:14 -04:00
Filipi da Silva Fuchter
367200c0ad Merge pull request #2682 from pipecat-ai/filipi/smallwebrtc_leak
Smallwebrtc memory leak
2025-09-18 18:56:08 -03:00
Filipi Fuchter
766e1948a6 Mentioning the fix in the changelog. 2025-09-18 18:43:33 -03:00
Aleix Conchillo Flaqué
f369683b8b audio(vad): update Silero VAD model to v6 2025-09-18 14:06:37 -07:00
Aleix Conchillo Flaqué
461025d1cc Merge pull request #2684 from pipecat-ai/aleix/readme-whisker
README: add whisker debugger
2025-09-18 13:27:35 -07:00
Aleix Conchillo Flaqué
ac88706f38 README: add whisker debugger 2025-09-18 13:22:54 -07:00
Filipi Fuchter
93a89449b8 Adding warnings in case queue grows. 2025-09-18 16:43:57 -03:00
Filipi Fuchter
199bf72945 Preventing memory growth if we are not consuming the track. 2025-09-18 16:16:10 -03:00
Filipi Fuchter
d20e4125f6 Updating aiortc to the latest version. 2025-09-18 15:22:46 -03:00
Filipi Fuchter
c1baed642e Script to monitor memory usage. 2025-09-18 14:43:42 -03:00
Mark Backman
33ef68573f Merge pull request #2662 from pelguetat/fix-vertex-ai-global-location-support
feat: add support for global location in Vertex AI base URL
2025-09-18 10:25:10 -07:00
Pablo Elgueta
3c1b41df13 docs: add changelog entry for global location support
- Document the new global location support in GoogleVertexLLMService
- Explain the difference between regional and global API hosts
- Follow Keep a Changelog format
2025-09-18 17:39:03 +01:00
kompfner
fca4ecc73c Merge pull request #2675 from pipecat-ai/pk/service-switcher-logic-simplification
Simplify a bit of logic in `ServiceSwitcher`
2025-09-18 09:17:22 -04:00
Paul Kompfner
cfa333508b Simplify a bit of logic in ServiceSwitcher 2025-09-17 21:03:38 -04:00
Mark Backman
9e7260393a Merge pull request #2671 from pipecat-ai/mb/fix-asyncai-ttstextframe
fix: AsyncAITTSService wasn't pushing TTSTextFrames
2025-09-17 14:06:41 -07:00
Mark Backman
073b585c52 fix: AsyncAITTSService wasn't pushing TTSTextFrames 2025-09-17 16:54:18 -04:00
Aleix Conchillo Flaqué
81c2e51bec Merge pull request #2669 from pipecat-ai/aleix/interruption-task-frame-wait-fixes
interruption task frame wait fixes
2025-09-17 13:47:57 -07:00
Aleix Conchillo Flaqué
42344125b1 tests: add unit tests for push_interruption_task_frame_and_wait() 2025-09-17 13:38:22 -07:00
Aleix Conchillo Flaqué
db5bcfaa51 FrameProcessor: fix push_interruption_task_frame_and_wait() 2025-09-17 13:38:21 -07:00
kompfner
615239b7d2 Merge pull request #2646 from pipecat-ai/pk/service-switcher-unit-tests
`ServiceSwitcher` unit tests (ended up being much more than that)
2025-09-17 16:30:18 -04:00
Paul Kompfner
27f1e9dd69 Update CHANGELOG with a description of the recently-fixed ServiceSwitcher bugs 2025-09-17 16:27:12 -04:00
Paul Kompfner
bd760deff2 Update comment with more detail for posterity 2025-09-17 16:19:31 -04:00
Paul Kompfner
8bc3c89140 Fix a bug preventing usage of multiple ServiceSwitchers in a pipeline 2025-09-17 16:09:18 -04:00
Paul Kompfner
2cd2567a37 Add a unit test validating that multiple ServiceSwitchers can be used in the same pipeline (currently failing) 2025-09-17 16:04:30 -04:00
Paul Kompfner
5b55988846 Denote a couple of variables are private with a leading underscore 2025-09-17 15:38:28 -04:00
Paul Kompfner
a12392182c Simplify, undoing the change allowing controlling ServiceSwitcher with immediate frames (SystemFrames). Service switcher frames are ControlFrames, which are easier to reason about. We can always build the immediate option later if needed (i.e. if there's sufficient user pull for it) 2025-09-17 15:35:02 -04:00
Paul Kompfner
b814b70e1e Allow controlling ServiceSwitcher with either immediate frames (SystemFrames) or in-order frames (ControlFrames).
Immediate is the "default", i.e. has the more obvious name (e.g. `ManuallySwitchServiceFrame` v `ManuallySwitchServiceControlFrame`), since that's *probably* what users will want to reach for. Also, the immediate frames are more likely to behave like what we had before the last few commits, where the service switch would always "jump the queue" by having an immediate effect once it hit the `ServiceSwitcher` in the pipeline, jumping ahead of frames in front of it destined for the service.
2025-09-17 15:35:02 -04:00
Paul Kompfner
a1f84e1b50 Remove extraneous unit tests 2025-09-17 15:35:02 -04:00
Paul Kompfner
0839b48da8 Fix an issue where the upstream ServiceSwitcherFilter wouldn't get updated with the current active service 2025-09-17 15:35:02 -04:00
Paul Kompfner
de51637b77 Update ServiceSwitcher so that ServiceSwitcherFrames (which might update the currently active service) are processed and have an effect at the expected time. We should be able to, for example, queue:
- A text frame
- A `ManuallySwitchServiceFrame` (which is a `ServiceSwitcherFrame`)
- Another text frame

And expect that the first text frame be handled by the initially active service and the second text frame be handled by the newly active one.

Previously, the `ManuallySwitchServiceFrame` would have an effect too early, causing both text frames to be handled by the newly active service. Why? Because the frame filtering condition was being updated *directly* by the `ServiceSwitcher`, which is upstream from the services it's switching between. It could therefore update the filters *before* the services received the prior frames.
2025-09-17 15:35:02 -04:00
Paul Kompfner
e1b1dc16ec Add unit tests for ServiceSwitcher 2025-09-17 15:35:02 -04:00
Mark Backman
1fe27eb0a2 Merge pull request #2660 from pipecat-ai/mb/fix-user-idle-processor-cancel-task
fix: clean up how UserIdleProcessor handles return False
2025-09-16 14:48:59 -07:00
Mark Backman
d7e1389497 fix: clean up how UserIdleProcessor handles return False 2025-09-16 17:44:06 -04:00
Aleix Conchillo Flaqué
8c7230aa8f Merge pull request #2668 from pipecat-ai/aleix/livekit-update
livekit package update
2025-09-16 14:43:18 -07:00
Aleix Conchillo Flaqué
2cf71239b0 examples(01b): use TTSSpeakFrame instead of TextFrame 2025-09-16 17:18:45 -04:00
Aleix Conchillo Flaqué
ec2c62e32b pyproject: update to livekit 1.0.13
Fixes #2643
2025-09-16 17:18:44 -04:00
Mark Backman
38ce85e9a0 Merge pull request #2667 from zytegalaxy/mcp-serverparameters-typefix
fix: replace `Tuple` type with `TypeAlias` for server params in MCP client
2025-09-16 14:14:59 -07:00
Mark Backman
2279e5a899 Merge pull request #2663 from pipecat-ai/mb/websockets-15
Add support for websockets 15.0
2025-09-16 14:08:36 -07:00
Mark Backman
cce6eb5d87 Merge pull request #2666 from pipecat-ai/mb/update-38b-local-turn-model
38b: Update bundled ONNX smart-turn model
2025-09-16 14:05:12 -07:00
mehrdad
c2b98ae557 fix(lint): fix space format issue 2025-09-16 13:44:15 -07:00
Filipi da Silva Fuchter
727eb12b16 Merge pull request #2648 from pipecat-ai/filipi/pcc_small_webrtc
Creating SmallWebRTCRequestHandler for managing peer connections.
2025-09-16 16:37:04 -03:00
mehrdad
ba96bd05d3 fix: replace Tuple type with TypeAlias for server params in MCP client 2025-09-16 11:44:25 -07:00
Mark Backman
8ead309f8d 38b: Update bundled ONNX smart-turn model 2025-09-16 13:17:14 -04:00
Mark Backman
fad0e55c64 Add websockets-base optional dependency and use for DRY pyproject.toml 2025-09-16 11:24:38 -04:00
Mark Backman
74b1af56a0 Update uv.lock 2025-09-16 11:21:49 -04:00
Mark Backman
6924850ec4 Add support for websockets 15.0 2025-09-16 11:21:49 -04:00
marcus-daily
dfe7815dc5 Smart Turn v3: removing torch and torchaudio deps 2025-09-16 16:02:41 +01:00
Pablo Elgueta
69f0a75882 feat: add support for global location in Vertex AI base URL
- Update _get_base_url method to handle 'global' location case
- Use 'aiplatform.googleapis.com' for global locations
- Use '{location}-aiplatform.googleapis.com' for regional locations
- Maintains backward compatibility with existing regional endpoints
2025-09-16 10:28:22 -03:00
Filipi Fuchter
0a043154f2 Removing not used import. 2025-09-15 10:46:43 -03:00
Filipi Fuchter
5e322eba9e Supporting both single and multiple connection modes. 2025-09-15 10:43:46 -03:00
Filipi Fuchter
11d0c3d46d Refactoring SmallWebRTCRequestHandler. 2025-09-15 09:58:44 -03:00
Filipi Fuchter
95f72f6dce Creating SmallWebRTCRequestHandler for managing peer connections. 2025-09-12 18:15:24 -03:00
26 changed files with 2388 additions and 1439 deletions
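
The host selection described in commit 69f0a75882 can be sketched as follows. This is a minimal illustration, not the code from `GoogleVertexLLMService`: the function name `vertex_api_host` is hypothetical, and only the two host patterns are taken from the commit message.

```python
def vertex_api_host(location: str) -> str:
    """Return the Vertex AI API host for a location (hypothetical helper)."""
    # Per the commit message, the "global" location uses the un-prefixed host.
    if location == "global":
        return "aiplatform.googleapis.com"
    # Regional locations are prefixed with the region name.
    return f"{location}-aiplatform.googleapis.com"


print(vertex_api_host("global"))
print(vertex_api_host("us-east4"))
```

Existing regional configurations keep resolving to the same host as before, which is the backward-compatibility point the commit message makes.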


@@ -9,6 +9,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `on_before_disconnect` synchronous event to `DailyTransport` and
`LiveKitTransport`.
- It is now possible to register synchronous event handlers. By default, all
event handlers are executed in a separate task. However, in some cases we want
to guarantee order of execution, for example, executing something before
disconnecting a transport.
```python
self._register_event_handler("on_event_name", sync=True)
```
- Added support for global location in `GoogleVertexLLMService`. The service now
supports both regional locations (e.g., "us-east4") and the "global" location
for Vertex AI endpoints. When using "global" location, the service will use
`aiplatform.googleapis.com` as the API host instead of the regional format.
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
fired when the pipeline is done running. This can be the result of a
`StopFrame`, `CancelFrame` or `EndFrame`.
@@ -19,14 +36,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
...
```
### Changed
- Updated Silero VAD model to v6.
- Updated `livekit` to 1.0.13.
- `torch` and `torchaudio` are no longer required for running Smart Turn
locally. This avoids gigabytes of dependencies being installed.
- Updated `websockets` dependency to support version 15.0. Removed deprecated
usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
`AWSTranscribeSTTService` for compatibility.
- Refactored `pyproject.toml` to reduce websockets dependency repetition using
self-referencing extras. All websockets-dependent services now reference a
shared `websockets-base` extra.
### Deprecated
- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
longer needed to determine which transcription or translation frames to
emit.
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
instead.
### Fixed
- Fixed an issue where multiple handlers for an event would not run in parallel.
- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
ID from the `on_dialin_connected` event, when not explicitly provided. Now
supports cold transfers (from incoming dial-in calls) by automatically
tracking session IDs from connection events.
- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
the code never consumes these frames, they are queued in memory, causing a
memory leak.
- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
pushed.
- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
not wait if a previous interruption had already happened.
- Fixed a couple of bugs in `ServiceSwitcher`:
- Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
- `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
an effect too early, essentially "jumping the queue" in terms of pipeline
frame ordering.
- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
`False` from an idle callback. The task now terminates naturally instead of
attempting to cancel itself.
- Fixed an issue in `AudioBufferProcessor` where a recording was not created
when the bot spoke and user input was blocked.
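
The difference between the default (one task per handler) and the new synchronous event handlers can be illustrated with a toy emitter. This is a sketch under stated assumptions, not Pipecat's `BaseObject`: the class `ToyEmitter` and its methods `register`, `add_handler`, `emit` and `drain` are hypothetical names chosen for the example.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, Set

Handler = Callable[[str], Awaitable[None]]


class ToyEmitter:
    """Toy event emitter (hypothetical, not Pipecat's BaseObject)."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Handler]] = {}
        self._sync: Dict[str, bool] = {}
        self._tasks: Set[asyncio.Task] = set()

    def register(self, name: str, sync: bool = False) -> None:
        self._handlers[name] = []
        self._sync[name] = sync

    def add_handler(self, name: str, handler: Handler) -> None:
        self._handlers[name].append(handler)

    async def emit(self, name: str, payload: str) -> None:
        for handler in self._handlers[name]:
            if self._sync[name]:
                # Synchronous event: wait for the handler before continuing.
                await handler(payload)
            else:
                # Default: each handler runs in its own background task.
                task = asyncio.create_task(handler(payload))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)

    async def drain(self) -> None:
        # Wait for any background handlers that are still running.
        if self._tasks:
            await asyncio.gather(*self._tasks)


async def demo_handlers() -> List[str]:
    log: List[str] = []
    emitter = ToyEmitter()
    emitter.register("on_before_disconnect", sync=True)
    emitter.register("on_disconnected")

    async def slow(payload: str) -> None:
        await asyncio.sleep(0.01)
        log.append(f"handled {payload}")

    emitter.add_handler("on_before_disconnect", slow)
    emitter.add_handler("on_disconnected", slow)

    await emitter.emit("on_before_disconnect", "before")
    log.append("disconnecting")  # reached only after the sync handler finished
    await emitter.emit("on_disconnected", "after")
    log.append("emit returned")  # the background handler has not finished yet
    await emitter.drain()
    return log


if __name__ == "__main__":
    print(asyncio.run(demo_handlers()))
```

In this toy model the synchronous handler is guaranteed to finish before the code following `emit()` runs, which is the ordering guarantee the changelog entry describes for work that must happen before a transport disconnects.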


@@ -21,6 +21,8 @@
🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.
🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
## 🧠 Why Pipecat?
- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling


@@ -11,7 +11,7 @@ import sys
 from dotenv import load_dotenv
 from loguru import logger
-from pipecat.frames.frames import TextFrame
+from pipecat.frames.frames import TTSSpeakFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineTask
@@ -50,7 +50,7 @@ async def main():
     async def on_first_participant_joined(transport, participant_id):
         await asyncio.sleep(1)
         await task.queue_frame(
-            TextFrame(
+            TTSSpeakFrame(
                 "Hello there! How are you doing today? Would you like to talk about the weather?"
             )
         )


@@ -30,10 +30,6 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
 load_dotenv(override=True)
-# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
-# to the Smart Turn v3 ONNX model file.
-smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
 # We store functions so objects (e.g. SileroVADAnalyzer) don't get
 # instantiated. The function will be called when the desired transport gets
 # selected.
@@ -42,25 +38,19 @@ transport_params = {
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(
-            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
-        ),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
     ),
     "twilio": lambda: FastAPIWebsocketParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(
-            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
-        ),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
     ),
     "webrtc": lambda: TransportParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(
-            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
-        ),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
     ),
 }


@@ -47,32 +47,32 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
aic = [ "aic-sdk~=1.0.1" ]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "websockets>=13.1,<15.0" ]
asyncai = [ "websockets>=13.1,<15.0" ]
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
assemblyai = [ "pipecat-ai[websockets-base]" ]
asyncai = [ "pipecat-ai[websockets-base]" ]
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
cerebras = []
deepseek = []
daily = [ "daily-python~=0.19.9" ]
deepgram = [ "deepgram-sdk~=4.7.0" ]
elevenlabs = [ "websockets>=13.1,<15.0" ]
elevenlabs = [ "pipecat-ai[websockets-base]" ]
fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
gladia = [ "websockets>=13.1,<15.0" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
gladia = [ "pipecat-ai[websockets-base]" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]
heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
inworld = []
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
koala = [ "pvkoala~=2.0.3" ]
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
lmnt = [ "websockets>=13.1,<15.0" ]
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
lmnt = [ "pipecat-ai[websockets-base]" ]
local = [ "pyaudio~=0.2.14" ]
mcp = [ "mcp[cli]~=1.9.4" ]
mem0 = [ "mem0ai~=0.1.94" ]
@@ -80,34 +80,35 @@ mistral = []
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
nim = []
neuphonic = [ "websockets>=13.1,<15.0" ]
neuphonic = [ "pipecat-ai[websockets-base]" ]
noisereduce = [ "noisereduce~=3.0.3" ]
openai = [ "websockets>=13.1,<15.0" ]
openai = [ "pipecat-ai[websockets-base]" ]
openpipe = [ "openpipe~=4.50.0" ]
openrouter = []
perplexity = []
playht = [ "websockets>=13.1,<15.0" ]
playht = [ "pipecat-ai[websockets-base]" ]
qwen = []
rime = [ "websockets>=13.1,<15.0" ]
rime = [ "pipecat-ai[websockets-base]" ]
riva = [ "nvidia-riva-client~=2.21.1" ]
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
sambanova = []
sarvam = [ "websockets>=13.1,<15.0" ]
sarvam = [ "pipecat-ai[websockets-base]" ]
sentry = [ "sentry-sdk~=2.23.1" ]
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime>=1.20.1, <2" ]
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
remote-smart-turn = []
silero = [ "onnxruntime>=1.20.1, <2" ]
simli = [ "simli-ai~=0.1.10"]
soniox = [ "websockets>=13.1,<15.0" ]
soniox = [ "pipecat-ai[websockets-base]" ]
soundfile = [ "soundfile~=0.13.0" ]
speechmatics = [ "speechmatics-rt>=0.4.0" ]
tavus=[]
together = []
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
websockets-base = [ "websockets>=13.1,<16.0" ]
whisper = [ "faster-whisper~=1.1.1" ]
[dependency-groups]

scripts/mem-watch.sh (executable file, 12 lines added)

@@ -0,0 +1,12 @@
+#!/bin/bash
+PID=$1
+while true; do
+    # Clear the screen
+    clear
+    # Print the header + RSS in GB
+    ps -p "$PID" -o pid,comm,rss | \
+    awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
+    sleep 1
+done
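
The unit conversion performed by the `awk` expression above can be sketched in Python. This assumes, as on Linux, that `ps -o rss` reports resident set size in kibibytes; the helper names `rss_kib_to_gib` and `format_row` are hypothetical and are not part of the script.

```python
def rss_kib_to_gib(rss_kib: int) -> float:
    """Convert an RSS value in KiB to GiB (two divisions by 1024)."""
    return rss_kib / 1024 / 1024


def format_row(pid: int, comm: str, rss_kib: int) -> str:
    """Mirror the awk printf: pid, command, raw RSS, RSS in GiB (2 decimals)."""
    return "%s %s %s %.2f" % (pid, comm, rss_kib, rss_kib_to_gib(rss_kib))


print(format_row(1234, "python", 2097152))
```

A process holding 2097152 KiB of resident memory would therefore be reported as 2.00 in the `rss_GB` column.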


@@ -98,15 +98,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
         inputs = self._feature_extractor(
             audio_array,
             sampling_rate=16000,
-            return_tensors="pt",
+            return_tensors="np",
             padding="max_length",
             max_length=8 * 16000,
             truncation=True,
             do_normalize=True,
         )
-        # Convert to numpy and ensure correct shape for ONNX
-        input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32)
+        # Extract features and ensure correct shape for ONNX
+        input_features = inputs.input_features.squeeze(0).astype(np.float32)
         input_features = np.expand_dims(input_features, axis=0)  # Add batch dimension
         # Run ONNX inference


@@ -1604,7 +1604,7 @@ class MixerEnableFrame(MixerControlFrame):
 @dataclass
 class ServiceSwitcherFrame(ControlFrame):
-    """A base class for frames that control ServiceSwitcher behavior."""
+    """A base class for frames that affect ServiceSwitcher behavior."""
     pass


@@ -6,9 +6,15 @@
"""Service switcher for switching between different services at runtime, with different switching strategies."""
from dataclasses import dataclass
from typing import Any, Generic, List, Optional, Type, TypeVar
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
from pipecat.frames.frames import (
ControlFrame,
Frame,
ManuallySwitchServiceFrame,
ServiceSwitcherFrame,
)
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -22,19 +28,6 @@ class ServiceSwitcherStrategy:
self.services = services
self.active_service: Optional[FrameProcessor] = None
def is_active(self, service: FrameProcessor) -> bool:
"""Determine if the given service is the currently active one.
This method should be overridden by subclasses to implement specific logic.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
raise NotImplementedError("Subclasses must implement this method.")
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -60,17 +53,6 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
super().__init__(services)
self.active_service = services[0] if services else None
def is_active(self, service: FrameProcessor) -> bool:
"""Check if the given service is the currently active one.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
return service == self.active_service
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -79,20 +61,21 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
direction: The direction of the frame (upstream or downstream).
"""
if isinstance(frame, ManuallySwitchServiceFrame):
self._set_active(frame.service)
self._set_active_if_available(frame.service)
else:
raise ValueError(f"Unsupported frame type: {type(frame)}")
def _set_active(self, service: FrameProcessor):
"""Set the active service to the given one.
def _set_active_if_available(self, service: FrameProcessor):
"""Set the active service to the given one, if it is in the list of available services.
If it's not in the list, the request is ignored, as it may have been
intended for another ServiceSwitcher in the pipeline.
Args:
service: The service to set as active.
"""
if service in self.services:
self.active_service = service
else:
raise ValueError(f"Service {service} is not in the list of available services.")
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
@@ -108,6 +91,43 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
self.services = services
self.strategy = strategy
class ServiceSwitcherFilter(FunctionFilter):
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
def __init__(
self,
wrapped_service: FrameProcessor,
active_service: FrameProcessor,
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction."""
async def filter(_: Frame) -> bool:
return self._wrapped_service == self._active_service
super().__init__(filter, direction)
self._wrapped_service = wrapped_service
self._active_service = active_service
async def process_frame(self, frame, direction):
"""Process a frame through the filter, handling special internal filter-updating frames."""
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
self._active_service = frame.active_service
# Two ServiceSwitcherFilters "sandwich" a service. Push the
# frame only to update the other side of the sandwich, but
# otherwise don't let it leave the sandwich.
if direction == self._direction:
await self.push_frame(frame, direction)
return
await super().process_frame(frame, direction)
@dataclass
class ServiceSwitcherFilterFrame(ControlFrame):
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
active_service: FrameProcessor
@staticmethod
def _make_pipeline_definitions(
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
@@ -121,14 +141,18 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
def _make_pipeline_definition(
service: FrameProcessor, strategy: ServiceSwitcherStrategy
) -> Any:
async def filter(frame) -> bool:
_ = frame
return strategy.is_active(service)
return [
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.DOWNSTREAM,
),
service,
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.UPSTREAM,
),
]
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -142,3 +166,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
if isinstance(frame, ServiceSwitcherFrame):
self.strategy.handle_frame(frame, direction)
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
active_service=self.strategy.active_service
)
await super().process_frame(service_switcher_filter_frame, direction)
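
The ordering fix in this file can be illustrated with a toy model. It is a simplified sketch, not the Pipecat classes: `Text`, `Switch`, `run_direct_update` and `run_in_band_update` are hypothetical names, and a plain queue stands in for the pipeline between the switcher and the per-service filters.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List, Tuple, Union


@dataclass
class Text:
    text: str


@dataclass
class Switch:
    service: str


Frame = Union[Text, Switch]


def run_direct_update(frames: List[Frame], initial: str) -> List[Tuple[str, str]]:
    """Old behaviour (simplified): the switcher mutates shared state directly."""
    active = initial
    queue: Deque[Text] = deque()
    # Upstream stage: the switcher sees every frame before the services do.
    for frame in frames:
        if isinstance(frame, Switch):
            active = frame.service  # takes effect immediately
        else:
            queue.append(frame)
    # Downstream stage: the filters consult the already-updated state.
    return [(active, frame.text) for frame in queue]


def run_in_band_update(frames: List[Frame], initial: str) -> List[Tuple[str, str]]:
    """New behaviour (simplified): the update travels in order with the frames."""
    queue: Deque[Frame] = deque(frames)  # everything is forwarded in order
    active = initial  # state owned by the filter itself
    handled: List[Tuple[str, str]] = []
    while queue:
        frame = queue.popleft()
        if isinstance(frame, Switch):
            active = frame.service  # applied only when the filter reaches it
        else:
            handled.append((active, frame.text))
    return handled


frames: List[Frame] = [Text("one"), Switch("B"), Text("two")]
print(run_direct_update(frames, "A"))
print(run_in_band_update(frames, "A"))
```

In the direct-update model both text frames are handled by service "B", which is the "jumping the queue" bug from the changelog. In the in-band model the first text frame is handled by "A" and the second by "B", matching the expectation stated in commit de51637b77.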


@@ -220,6 +220,11 @@ class FrameProcessor(BaseObject):
self.__process_event: Optional[asyncio.Event] = None
self.__process_frame_task: Optional[asyncio.Task] = None
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
# Then we wait for the corresponding `InterruptionFrame` to travel from
# the start of the pipeline back to the processor that sent the
# `InterruptionTaskFrame`. This wait is handled using the following
# event.
self._wait_for_interruption = False
self._wait_interruption_event = asyncio.Event()
@@ -563,11 +568,17 @@ class FrameProcessor(BaseObject):
"""Pause processing of queued frames."""
logger.trace(f"{self}: pausing frame processing")
self.__should_block_frames = True
# We should also unset the process event here, in case it was set immediately after an interruption
if self.__process_event:
self.__process_event.clear()
async def pause_processing_system_frames(self):
"""Pause processing of queued system frames."""
logger.trace(f"{self}: pausing system frame processing")
self.__should_block_system_frames = True
# We should also unset the input event here, in case it was set immediately after an interruption
if self.__input_event:
self.__input_event.clear()
async def resume_processing_frames(self):
"""Resume processing of queued frames."""
@@ -632,7 +643,9 @@ class FrameProcessor(BaseObject):
await self.__internal_push_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
# If we are waiting for an interruption and we get an interruption, then
# we can unblock `push_interruption_task_frame_and_wait()`.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
self._wait_interruption_event.set()
async def push_interruption_task_frame_and_wait(self):


@@ -17,7 +17,6 @@ from pipecat.frames.frames import (
     Frame,
     FunctionCallInProgressFrame,
     FunctionCallResultFrame,
-    StartFrame,
     UserStartedSpeakingFrame,
     UserStoppedSpeakingFrame,
 )
@@ -185,15 +184,13 @@ class UserIdleProcessor(FrameProcessor):
         Runs in a loop until cancelled or callback indicates completion.
         """
-        while True:
+        running = True
+        while running:
             try:
                 await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
             except asyncio.TimeoutError:
                 if not self._interrupted:
                     self._retry_count += 1
-                    should_continue = await self._callback(self, self._retry_count)
-                    if not should_continue:
-                        await self._stop()
-                        break
+                    running = await self._callback(self, self._retry_count)
             finally:
                 self._idle_event.clear()
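
The pattern adopted here, letting the loop end on its own instead of having the task cancel itself, can be sketched in isolation. This is a minimal standalone illustration rather than the `UserIdleProcessor` code: `idle_loop` and `demo_idle` are hypothetical names, and the callback takes only the retry count.

```python
import asyncio
from typing import Awaitable, Callable, List


async def idle_loop(
    idle_event: asyncio.Event,
    timeout: float,
    callback: Callable[[int], Awaitable[bool]],
) -> int:
    """Call `callback` on every idle timeout until it returns False."""
    retry_count = 0
    running = True
    while running:
        try:
            await asyncio.wait_for(idle_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            retry_count += 1
            # A False return ends the loop, so the task finishes naturally
            # and never has to cancel (and then await) itself.
            running = await callback(retry_count)
        finally:
            idle_event.clear()
    return retry_count


async def demo_idle() -> List[int]:
    calls: List[int] = []

    async def on_idle(retry: int) -> bool:
        calls.append(retry)
        return retry < 3  # stop after the third timeout

    task = asyncio.create_task(idle_loop(asyncio.Event(), 0.01, on_idle))
    await task  # completes without any cancellation
    return calls


if __name__ == "__main__":
    print(asyncio.run(demo_idle()))
```

Because the loop condition is the callback's return value, the task that runs the loop is never asked to cancel itself from inside its own body, which is the deadlock the changelog entry refers to.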


@@ -70,7 +70,6 @@ import asyncio
import os
import sys
from contextlib import asynccontextmanager
from typing import Dict
from loguru import logger
@@ -183,13 +182,14 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.transports.smallwebrtc.request_handler import (
SmallWebRTCRequest,
SmallWebRTCRequestHandler,
)
except ImportError as e:
logger.error(f"WebRTC transport dependencies not installed: {e}")
return
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
# Mount the frontend
app.mount("/client", SmallWebRTCPrebuiltUI)
@@ -198,51 +198,33 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
"""Redirect root requests to client interface."""
return RedirectResponse(url="/client/")
# Initialize the SmallWebRTC request handler
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
esp32_mode=esp32_mode, host=host
)
@app.post("/api/offer")
async def offer(request: dict, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests and manage peer connections."""
pc_id = request.get("pc_id")
if pc_id and pc_id in pcs_map:
pipecat_connection = pcs_map[pc_id]
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request["sdp"],
type=request["type"],
restart_pc=request.get("restart_pc", False),
)
else:
pipecat_connection = SmallWebRTCConnection()
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
"""Handle WebRTC connection closure and cleanup."""
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
pcs_map.pop(webrtc_connection.pc_id, None)
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
# Prepare runner arguments with the callback to run your bot
async def webrtc_connection_callback(connection):
bot_module = _get_bot_module()
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
background_tasks.add_task(bot_module.bot, runner_args)
answer = pipecat_connection.get_answer()
# Apply ESP32 SDP munging if enabled
if esp32_mode and host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
pcs_map[answer["pc_id"]] = pipecat_connection
# Delegate handling to SmallWebRTCRequestHandler
answer = await small_webrtc_handler.handle_web_request(
request=request,
webrtc_connection_callback=webrtc_connection_callback,
)
return answer
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage FastAPI application lifecycle and cleanup connections."""
yield
coros = [pc.disconnect() for pc in pcs_map.values()]
await asyncio.gather(*coros)
pcs_map.clear()
await small_webrtc_handler.close()
app.router.lifespan_context = lifespan
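
Review note: the guarded import at the top of `_setup_webrtc_routes` defers loading the SmallWebRTC request classes until the routes are actually set up. A minimal sketch of that lazy-import pattern follows, assuming a hypothetical helper `load_optional_attr` (not part of Pipecat) and using only the standard library:

```python
import importlib
import logging
from typing import Any, Optional

log = logging.getLogger(__name__)


def load_optional_attr(module_name: str, attr: str) -> Optional[Any]:
    """Import `module_name` on demand and return `attr`, or None if unavailable.

    Hypothetical helper for illustration only; the runner in the diff uses a
    plain try/except around a `from ... import ...` statement instead.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        # Mirror the diff: log the missing dependency and give up quietly.
        log.error("Optional dependency not installed: %s", e)
        return None
    return getattr(module, attr, None)
```

Callers would check for `None` and skip the feature, which is roughly what the early `return` after `logger.error` does in the diff.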

View File

@@ -119,7 +119,6 @@ class AsyncAITTSService(InterruptibleTTSService):
"""
super().__init__(
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
push_stop_frames=True,
sample_rate=sample_rate,

View File

@@ -532,9 +532,7 @@ class AWSTranscribeSTTService(STTService):
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")

View File

@@ -13,6 +13,7 @@ supporting multiple languages, custom vocabulary, and various audio processing o
import asyncio
import base64
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Literal, Optional
import aiohttp
@@ -173,8 +174,6 @@ class _InputParamsDescriptor:
"""Descriptor for backward compatibility with deprecation warning."""
def __get__(self, obj, objtype=None):
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -208,7 +207,7 @@ class GladiaSTTService(STTService):
api_key: str,
region: Literal["us-west", "eu-west"] | None = None,
url: str = "https://api.gladia.io/v2/live",
confidence: float = 0.5,
confidence: Optional[float] = None,
sample_rate: Optional[int] = None,
model: str = "solaria-1",
params: Optional[GladiaInputParams] = None,
@@ -224,6 +223,11 @@ class GladiaSTTService(STTService):
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
.. deprecated:: 0.0.86
The 'confidence' parameter is deprecated and will be removed in a future version.
No confidence threshold is applied.
sample_rate: Audio sample rate in Hz. If None, uses service default.
model: Model to use for transcription. Defaults to "solaria-1".
params: Additional configuration parameters for Gladia service.
@@ -236,7 +240,6 @@ class GladiaSTTService(STTService):
params = params or GladiaInputParams()
# Warn about deprecated language parameter if it's used
if params.language is not None:
with warnings.catch_warnings():
warnings.simplefilter("always")
@@ -247,11 +250,20 @@ class GladiaSTTService(STTService):
stacklevel=2,
)
if confidence:
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"The 'confidence' parameter is deprecated and will be removed in a future version. "
"No confidence threshold is applied.",
DeprecationWarning,
stacklevel=2,
)
self._api_key = api_key
self._region = region
self._url = url
self.set_model_name(model)
self._confidence = confidence
self._params = params
self._websocket = None
self._receive_task = None
@@ -575,43 +587,40 @@ class GladiaSTTService(STTService):
elif content["type"] == "transcript":
utterance = content["data"]["utterance"]
confidence = utterance.get("confidence", 0)
language = utterance["language"]
transcript = utterance["text"]
is_final = content["data"]["is_final"]
if confidence >= self._confidence:
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
elif content["type"] == "translation":
translated_utterance = content["data"]["translated_utterance"]
original_language = content["data"]["original_language"]
translated_language = translated_utterance["language"]
confidence = translated_utterance.get("confidence", 0)
translation = translated_utterance["text"]
if translated_language != original_language and confidence >= self._confidence:
if translated_language != original_language:
await self.push_frame(
TranslationFrame(
translation, "", time_now_iso8601(), translated_language
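
Review note: the `confidence` change above follows the usual pattern for deprecating a keyword argument: default it to `None`, warn when it is passed, and stop using the value. A minimal sketch, assuming a hypothetical `ExampleSTTService` class (not the real `GladiaSTTService`). Note that this sketch tests `is not None`, whereas the diff tests truthiness, so the two differ for `confidence=0.0`:

```python
import warnings
from typing import List, Optional


class ExampleSTTService:
    """Hypothetical stand-in used only to illustrate the deprecation pattern."""

    def __init__(self, api_key: str, confidence: Optional[float] = None):
        if confidence is not None:
            # The value is accepted for backward compatibility but ignored.
            warnings.warn(
                "The 'confidence' parameter is deprecated and will be removed "
                "in a future version. No confidence threshold is applied.",
                DeprecationWarning,
                stacklevel=2,
            )
        self._api_key = api_key


def collect_warning_categories(**kwargs) -> List[type]:
    """Construct the service and return the categories of warnings raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ExampleSTTService(api_key="dummy-key", **kwargs)
    return [w.category for w in caught]
```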

View File

@@ -83,14 +83,23 @@ class GoogleVertexLLMService(OpenAILLMService):
self._api_key = self._get_api_token(credentials, credentials_path)
super().__init__(
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
api_key=self._api_key,
base_url=base_url,
model=model,
params=params,
**kwargs,
)
@staticmethod
def _get_base_url(params: InputParams) -> str:
"""Construct the base URL for Vertex AI API."""
# Determine the correct API host based on location
if params.location == "global":
api_host = "aiplatform.googleapis.com"
else:
api_host = f"{params.location}-aiplatform.googleapis.com"
return (
f"https://{params.location}-aiplatform.googleapis.com/v1/"
f"https://{api_host}/v1/"
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
)
@@ -118,12 +127,14 @@ class GoogleVertexLLMService(OpenAILLMService):
if credentials:
# Parse and load credentials from JSON string
creds = service_account.Credentials.from_service_account_info(
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
json.loads(credentials),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
elif credentials_path:
# Load credentials from JSON file
creds = service_account.Credentials.from_service_account_file(
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
credentials_path,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
try:
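
Review note: the `_get_base_url` change can be read as a small pure function: the `global` location uses the unprefixed host, every other location uses a region-prefixed host. A sketch under that reading, with a hypothetical free function `vertex_base_url` and a made-up project ID:

```python
def vertex_base_url(location: str, project_id: str) -> str:
    """Build the OpenAI-compatible Vertex AI endpoint URL for a location.

    Hypothetical helper mirroring the logic shown in the diff.
    """
    if location == "global":
        api_host = "aiplatform.googleapis.com"
    else:
        api_host = f"{location}-aiplatform.googleapis.com"
    return (
        f"https://{api_host}/v1/"
        f"projects/{project_id}/locations/{location}/endpoints/openapi"
    )
```

The location still appears in the URL path in both cases; only the host name changes.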

View File

@@ -7,7 +7,7 @@
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
import json
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, TypeAlias
from loguru import logger
@@ -28,6 +28,8 @@ except ModuleNotFoundError as e:
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
raise Exception(f"Missing module: {e}")
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
class MCPClient(BaseObject):
"""Client for Model Context Protocol (MCP) servers.
@@ -42,7 +44,7 @@ class MCPClient(BaseObject):
def __init__(
self,
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
server_params: ServerParameters,
**kwargs,
):
"""Initialize the MCP client with server parameters.

View File

@@ -25,6 +25,7 @@ from pydantic import BaseModel
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
from pipecat.frames.frames import (
CancelFrame,
ControlFrame,
EndFrame,
ErrorFrame,
Frame,
@@ -41,6 +42,7 @@ from pipecat.frames.frames import (
UserAudioRawFrame,
UserImageRawFrame,
UserImageRequestFrame,
DataFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
from pipecat.transcriptions.language import Language
@@ -105,6 +107,17 @@ class DailyInputTransportMessageUrgentFrame(InputTransportMessageUrgentFrame):
participant_id: Optional[str] = None
@dataclass
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
"""Frame to update remote participants in Daily calls.
Parameters:
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
"""
remote_participants: Optional[Mapping[str, Any]] = None
class WebRTCVADAnalyzer(VADAnalyzer):
"""Voice Activity Detection analyzer using WebRTC.
@@ -215,6 +228,7 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Called when the active speaker of the call has changed.
on_joined: Called when bot successfully joined a room.
on_left: Called when bot left a room.
on_before_leave: Called when bot is about to leave the room.
on_error: Called when an error occurs.
on_app_message: Called when receiving an app message.
on_call_state_updated: Called when call state changes.
@@ -244,6 +258,7 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
on_left: Callable[[], Awaitable[None]]
on_before_leave: Callable[[], Awaitable[None]]
on_error: Callable[[str], Awaitable[None]]
on_app_message: Callable[[Any, str], Awaitable[None]]
on_call_state_updated: Callable[[str], Awaitable[None]]
@@ -359,6 +374,7 @@ class DailyTransportClient(EventHandler):
self._transcription_ids = []
self._transcription_status = None
self._dial_out_session_id: str = ""
self._dial_in_session_id: str = ""
self._joining = False
self._joined = False
@@ -719,6 +735,9 @@ class DailyTransportClient(EventHandler):
logger.info(f"Leaving {self._room_url}")
# Call callback before leaving.
await self._callbacks.on_before_leave()
if self._params.transcription_enabled:
await self.stop_transcription()
@@ -823,6 +842,16 @@ class DailyTransportClient(EventHandler):
Args:
settings: SIP call transfer settings.
"""
session_id = (
settings.get("sessionId") or self._dial_out_session_id or self._dial_in_session_id
)
if not session_id:
logger.error("Unable to transfer SIP call: 'sessionId' is not set")
return
# Update 'sessionId' field.
settings["sessionId"] = session_id
future = self._get_event_loop().create_future()
self._client.sip_call_transfer(settings, completion=completion_callback(future))
await future
@@ -1141,6 +1170,7 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in connection data.
"""
self._dial_in_session_id = data.get("sessionId", "")
self._call_event_callback(self._callbacks.on_dialin_connected, data)
def on_dialin_ready(self, sip_endpoint: str):
@@ -1157,6 +1187,9 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in stop data.
"""
# Cleanup only if our session stopped.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_stopped, data)
def on_dialin_error(self, data: Any):
@@ -1165,6 +1198,9 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in error data.
"""
# Cleanup only if our session errored out.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_error, data)
def on_dialin_warning(self, data: Any):
@@ -1199,7 +1235,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out stop data.
"""
# Cleanup only if our session stopped.
if data["sessionId"] == self._dial_out_session_id:
if data.get("sessionId") == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_stopped, data)
@@ -1210,7 +1246,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out error data.
"""
# Cleanup only if our session errored out.
if data["sessionId"] == self._dial_out_session_id:
if data.get("sessionId") == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_error, data)
@@ -1767,6 +1803,31 @@ class DailyOutputTransport(BaseOutputTransport):
# Leave the room.
await self._client.leave()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process outgoing frames, including transport messages.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
logger.debug(f"Got a DailyUpdateRemoteParticipantsFrame: {frame}")
await self._client.update_remote_participants(frame.remote_participants)
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
"""Send a transport message to participants.
@@ -1862,6 +1923,7 @@ class DailyTransport(BaseTransport):
on_active_speaker_changed=self._on_active_speaker_changed,
on_joined=self._on_joined,
on_left=self._on_left,
on_before_leave=self._on_before_leave,
on_error=self._on_error,
on_app_message=self._on_app_message,
on_call_state_updated=self._on_call_state_updated,
@@ -1925,6 +1987,10 @@ class DailyTransport(BaseTransport):
self._register_event_handler("on_recording_started")
self._register_event_handler("on_recording_stopped")
self._register_event_handler("on_recording_error")
self._register_event_handler("on_before_disconnect", sync=True)
# Deprecated
self._register_event_handler("on_joined")
self._register_event_handler("on_left")
#
# BaseTransport
@@ -2176,6 +2242,10 @@ class DailyTransport(BaseTransport):
"""Handle room left events."""
await self._call_event_handler("on_left")
async def _on_before_leave(self):
"""Handle the event fired just before the bot leaves the room."""
await self._call_event_handler("on_before_disconnect")
async def _on_error(self, error):
"""Handle error events and push error frames."""
await self._call_event_handler("on_error", error)
@@ -2315,7 +2385,7 @@ class DailyTransport(BaseTransport):
"""Handle participant updated events."""
await self._call_event_handler("on_participant_updated", participant)
async def _on_transcription_message(self, message: Dict[str, Any]) -> None:
async def _on_transcription_message(self, message: Mapping[str, Any]) -> None:
"""Handle transcription message events."""
await self._call_event_handler("on_transcription_message", message)
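
Review note: the `sip_call_transfer` hunk resolves the session ID with a fallback chain: an explicit `sessionId` in the settings wins, then the stored dial-out session, then the stored dial-in session. A sketch of that precedence as a hypothetical pure function `resolve_session_id` (the transport itself mutates `settings` in place and logs an error when nothing is found):

```python
from typing import Any, Dict, Optional


def resolve_session_id(
    settings: Dict[str, Any],
    dial_out_session_id: str = "",
    dial_in_session_id: str = "",
) -> Optional[str]:
    """Pick the session ID to use for a SIP call transfer, or None if unknown.

    Hypothetical helper; empty strings are treated as "not set", as in the diff.
    """
    return (
        settings.get("sessionId")
        or dial_out_session_id
        or dial_in_session_id
        or None
    )
```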

View File

@@ -114,6 +114,7 @@ class LiveKitCallbacks(BaseModel):
on_connected: Callable[[], Awaitable[None]]
on_disconnected: Callable[[], Awaitable[None]]
on_before_disconnect: Callable[[], Awaitable[None]]
on_participant_connected: Callable[[str], Awaitable[None]]
on_participant_disconnected: Callable[[str], Awaitable[None]]
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
@@ -282,6 +283,7 @@ class LiveKitTransportClient:
return
logger.info(f"Disconnecting from {self._room_name}")
await self._callbacks.on_before_disconnect()
await self.room.disconnect()
self._connected = False
logger.info(f"Disconnected from {self._room_name}")
@@ -918,6 +920,7 @@ class LiveKitTransport(BaseTransport):
callbacks = LiveKitCallbacks(
on_connected=self._on_connected,
on_disconnected=self._on_disconnected,
on_before_disconnect=self._on_before_disconnect,
on_participant_connected=self._on_participant_connected,
on_participant_disconnected=self._on_participant_disconnected,
on_audio_track_subscribed=self._on_audio_track_subscribed,
@@ -947,6 +950,7 @@ class LiveKitTransport(BaseTransport):
self._register_event_handler("on_first_participant_joined")
self._register_event_handler("on_participant_left")
self._register_event_handler("on_call_state_updated")
self._register_event_handler("on_before_disconnect", sync=True)
def input(self) -> LiveKitInputTransport:
"""Get the input transport for receiving media and events.
@@ -1041,6 +1045,10 @@ class LiveKitTransport(BaseTransport):
"""Handle room disconnected events."""
await self._call_event_handler("on_disconnected")
async def _on_before_disconnect(self):
"""Handle the event fired just before disconnecting from the room."""
await self._call_event_handler("on_before_disconnect")
async def _on_participant_connected(self, participant_id: str):
"""Handle participant connected events."""
await self._call_event_handler("on_participant_connected", participant_id)

View File

@@ -95,15 +95,20 @@ class SmallWebRTCTrack:
enable/disable control and frame discarding for audio and video streams.
"""
def __init__(self, track: MediaStreamTrack):
def __init__(self, receiver):
"""Initialize the WebRTC track wrapper.
Args:
track: The underlying MediaStreamTrack to wrap.
index: The index of the track in the transceiver (0 for mic, 1 for cam, 2 for screen)
receiver: The RemoteStreamTrack receiver instance.
"""
self._track = track
self._receiver = receiver
# Configure the receiver not to consume the track by default, to prevent memory growth
self._receiver._enabled = False
self._track = receiver.track
self._enabled = True
self._last_recv_time: float = 0.0
self._idle_task: Optional[asyncio.Task] = None
self._idle_timeout: float = 2.0 # seconds before discarding old frames
def set_enabled(self, enabled: bool) -> None:
"""Enable or disable the track.
@@ -138,13 +143,44 @@ class SmallWebRTCTrack:
async def recv(self) -> Optional[Frame]:
"""Receive the next frame from the track.
Enables the internal receiving state and starts idle watcher.
Returns:
The next frame. For video tracks, the frame is returned only if the track is enabled; otherwise None is returned.
"""
self._receiver._enabled = True
self._last_recv_time = time.time()
# start idle watcher if not already running
if not self._idle_task or self._idle_task.done():
self._idle_task = asyncio.create_task(self._idle_watcher())
if not self._enabled and self._track.kind == "video":
return None
return await self._track.recv()
async def _idle_watcher(self):
"""Disable receiving if idle for more than _idle_timeout and monitor queue size."""
while self._receiver._enabled:
await asyncio.sleep(self._idle_timeout)
idle_duration = time.time() - self._last_recv_time
if idle_duration >= self._idle_timeout:
# discard old frames to prevent memory growth
logger.debug(
f"Disabling receiver for {self._track.kind} track after {idle_duration:.2f}s idle"
)
await self.discard_old_frames()
self._receiver._enabled = False
def stop(self):
"""Stop receiving frames from the track."""
self._receiver._enabled = False
if self._idle_task:
self._idle_task.cancel()
self._idle_task = None
if self._track:
self._track.stop()
def __getattr__(self, name):
"""Forward attribute access to the underlying track.
@@ -454,6 +490,10 @@ class SmallWebRTCConnection(BaseObject):
async def _close(self):
"""Close the peer connection and cleanup resources."""
for track in self._track_map.values():
if track:
track.stop()
self._track_map.clear()
if self._pc:
await self._pc.close()
self._message_queue.clear()
@@ -526,8 +566,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No audio transceiver is available")
return None
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
audio_track = SmallWebRTCTrack(track) if track else None
receiver = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver
audio_track = SmallWebRTCTrack(receiver) if receiver else None
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
return audio_track
@@ -548,8 +588,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No video transceiver is available")
return None
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
receiver = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
@@ -570,8 +610,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No screen video transceiver is available")
return None
track = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
receiver = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
self._track_map[SCREEN_VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
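
Review note: the `SmallWebRTCTrack` change turns receiving on whenever `recv()` is called and lets a background watcher turn it off again after a period of inactivity. A minimal sketch of that idle-watcher pattern, assuming a hypothetical `IdleGate` class with a plain boolean in place of the receiver's private `_enabled` flag, and `time.monotonic()` where the diff uses `time.time()`:

```python
import asyncio
import time
from typing import Optional, Tuple


class IdleGate:
    """Hypothetical stand-in: enabled on use, disabled again after idling."""

    def __init__(self, idle_timeout: float):
        self.enabled = False
        self._idle_timeout = idle_timeout
        self._last_touch = 0.0
        self._task: Optional[asyncio.Task] = None

    def touch(self) -> None:
        """Record activity and make sure the watcher task is running."""
        self.enabled = True
        self._last_touch = time.monotonic()
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._watch())

    async def _watch(self) -> None:
        """Disable the gate once no activity was seen for a full timeout."""
        while self.enabled:
            await asyncio.sleep(self._idle_timeout)
            if time.monotonic() - self._last_touch >= self._idle_timeout:
                self.enabled = False

    def stop(self) -> None:
        """Disable the gate and cancel the watcher, if any."""
        self.enabled = False
        if self._task:
            self._task.cancel()
            self._task = None


async def demo() -> Tuple[bool, bool]:
    gate = IdleGate(idle_timeout=0.05)
    gate.touch()
    enabled_after_touch = gate.enabled
    await asyncio.sleep(0.2)  # no further activity, so the watcher disables it
    enabled_after_idle = gate.enabled
    gate.stop()
    return enabled_after_touch, enabled_after_idle
```

In the diff the watcher also discards old frames before disabling the receiver; that step is omitted here.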

View File

@@ -0,0 +1,200 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""SmallWebRTC request handler for managing peer connections.
This module provides a client for handling web requests and managing WebRTC connections.
"""
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional
from fastapi import HTTPException
from loguru import logger
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
@dataclass
class SmallWebRTCRequest:
"""Small WebRTC transport session arguments for the runner.
Parameters:
sdp: The SDP string (Session Description Protocol).
type: The type of the SDP, either "offer" or "answer".
pc_id: Optional identifier for the peer connection.
restart_pc: Optional whether to restart the peer connection.
request_data: Optional custom data sent by the customer.
"""
sdp: str
type: str
pc_id: Optional[str] = None
restart_pc: Optional[bool] = None
request_data: Optional[Any] = None
class ConnectionMode(Enum):
"""Enum defining the connection handling modes."""
SINGLE = "single" # Only one active connection allowed
MULTIPLE = "multiple" # Multiple simultaneous connections allowed
class SmallWebRTCRequestHandler:
"""SmallWebRTC request handler for managing peer connections.
This class is responsible for:
- Handling incoming SmallWebRTC requests.
- Creating and managing WebRTC peer connections.
- Supporting ESP32-specific SDP munging if enabled.
- Invoking callbacks for newly initialized connections.
- Supporting both single and multiple connection modes.
"""
def __init__(
self,
ice_servers: Optional[List[IceServer]] = None,
esp32_mode: bool = False,
host: Optional[str] = None,
connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
) -> None:
"""Initialize a SmallWebRTC request handler.
Args:
ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
connections.
esp32_mode (bool): If True, enables ESP32-specific SDP munging.
host (Optional[str]): Host address used for SDP munging in ESP32 mode.
Ignored if `esp32_mode` is False.
connection_mode (ConnectionMode): Mode of operation for handling connections.
SINGLE allows only one active connection, MULTIPLE allows several.
"""
self._ice_servers = ice_servers
self._esp32_mode = esp32_mode
self._host = host
self._connection_mode = connection_mode
# Store connections by pc_id
self._pcs_map: Dict[str, SmallWebRTCConnection] = {}
def _check_single_connection_constraints(self, pc_id: Optional[str]) -> None:
"""Check if the connection request satisfies single connection mode constraints.
Args:
pc_id: The peer connection ID from the request
Raises:
HTTPException: If constraints are violated in single connection mode
"""
if self._connection_mode != ConnectionMode.SINGLE:
return
if not self._pcs_map: # No existing connections
return
# Get the existing connection (should be only one in single mode)
existing_connection = next(iter(self._pcs_map.values()))
if existing_connection.pc_id != pc_id and pc_id:
logger.warning(
f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
)
raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")
if not pc_id:
logger.warning(
"Cannot create new connection: existing connection found but no pc_id received"
)
raise HTTPException(
status_code=400,
detail="Cannot create new connection with existing connection active",
)
async def handle_web_request(
self,
request: SmallWebRTCRequest,
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
) -> Dict[str, Any]:
"""Handle a SmallWebRTC request and resolve the pending answer.
This method will:
- Reuse an existing WebRTC connection if `pc_id` exists.
- Otherwise, create a new `SmallWebRTCConnection`.
- Invoke the provided callback with the connection.
- Manage ESP32-specific munging if enabled.
- Enforce single/multiple connection mode constraints.
Args:
request (SmallWebRTCRequest): The incoming WebRTC request, containing
SDP, type, and optionally a `pc_id`.
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
asynchronous callback function that is invoked with the WebRTC connection.
Raises:
HTTPException: If connection mode constraints are violated
Exception: Any exception raised during request handling or callback execution
will be logged and propagated.
"""
try:
pc_id = request.pc_id
# Check connection mode constraints first
self._check_single_connection_constraints(pc_id)
# After constraints are satisfied, get the existing connection if any
existing_connection = self._pcs_map.get(pc_id) if pc_id else None
if existing_connection:
pipecat_connection = existing_connection
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request.sdp,
type=request.type,
restart_pc=request.restart_pc or False,
)
else:
pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
await pipecat_connection.initialize(sdp=request.sdp, type=request.type)
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
self._pcs_map.pop(webrtc_connection.pc_id, None)
# Invoke callback provided in runner arguments
try:
await webrtc_connection_callback(pipecat_connection)
logger.debug(
f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
)
except Exception as callback_error:
logger.error(
f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
)
answer = pipecat_connection.get_answer()
if self._esp32_mode and self._host and self._host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)
self._pcs_map[answer["pc_id"]] = pipecat_connection
return answer
except Exception as e:
logger.error(f"Error processing SmallWebRTC request: {e}")
logger.debug(f"SmallWebRTC request details: {request}")
raise
async def close(self):
"""Disconnect all active connections and clear the connection map."""
coros = [pc.disconnect() for pc in self._pcs_map.values()]
await asyncio.gather(*coros)
self._pcs_map.clear()
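
Review note: in `SINGLE` mode, `_check_single_connection_constraints` lets a request through only when no connection exists yet or when the request names the existing connection's `pc_id`. A sketch of those rules as a hypothetical pure function, raising a local `ConnectionRejected` exception where the handler raises FastAPI's `HTTPException` with status 400:

```python
from typing import Optional, Set


class ConnectionRejected(Exception):
    """Hypothetical stand-in for the HTTP 400 error raised by the handler."""


def check_single_connection(active_pc_ids: Set[str], requested_pc_id: Optional[str]) -> None:
    """Raise ConnectionRejected if the request violates single-connection mode."""
    if not active_pc_ids:
        return  # Nothing is connected yet, so any request may proceed.
    existing = next(iter(active_pc_ids))
    if requested_pc_id and requested_pc_id != existing:
        raise ConnectionRejected("PC ID mismatch with existing connection")
    if not requested_pc_id:
        raise ConnectionRejected(
            "Cannot create new connection with existing connection active"
        )


def is_allowed(active_pc_ids: Set[str], requested_pc_id: Optional[str]) -> bool:
    """Convenience wrapper returning True or False instead of raising."""
    try:
        check_single_connection(active_pc_ids, requested_pc_id)
    except ConnectionRejected:
        return False
    return True
```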

View File

@@ -14,13 +14,33 @@ and async cleanup for all Pipecat components.
import asyncio
import inspect
from abc import ABC
from typing import Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from loguru import logger
from pipecat.utils.utils import obj_count, obj_id
@dataclass
class EventHandler:
"""Data class to store event handlers information.
This data class stores the event name, a list of handlers to run for this
event, and whether these handlers are awaited inline or run in separate tasks.
Attributes:
name (str): The name of the event handler.
handlers (List[Any]): A list of functions to be called when this event is triggered.
is_sync (bool): If True, handlers are awaited inline; otherwise each handler runs in its own task.
"""
name: str
handlers: List[Any]
is_sync: bool
class BaseObject(ABC):
"""Abstract base class providing common functionality for Pipecat objects.
@@ -41,7 +61,7 @@ class BaseObject(ABC):
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
# Registered event handlers.
self._event_handlers: dict = {}
self._event_handlers: Dict[str, EventHandler] = {}
# Set of tasks being executed. When a task finishes running it gets
# automatically removed from the set. When we cleanup we wait for all
@@ -103,18 +123,21 @@ class BaseObject(ABC):
Can be sync or async.
"""
if event_name in self._event_handlers:
self._event_handlers[event_name].append(handler)
self._event_handlers[event_name].handlers.append(handler)
else:
logger.warning(f"Event handler {event_name} not registered")
def _register_event_handler(self, event_name: str):
def _register_event_handler(self, event_name: str, sync: bool = False):
"""Register an event handler type.
Args:
event_name: The name of the event type to register.
sync: Whether handlers for this event are awaited inline instead of being run in separate tasks.
"""
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
self._event_handlers[event_name] = EventHandler(
name=event_name, handlers=[], is_sync=sync
)
else:
logger.warning(f"Event handler {event_name} already registered")
@@ -126,34 +149,43 @@ class BaseObject(ABC):
*args: Positional arguments to pass to event handlers.
**kwargs: Keyword arguments to pass to event handlers.
"""
# If we haven't registered an event handler, we don't need to do
# anything.
if not self._event_handlers.get(event_name):
if event_name not in self._event_handlers:
return
# Create the task.
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
event_handler = self._event_handlers[event_name]
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
for handler in event_handler.handlers:
if event_handler.is_sync:
# Just run the handler.
await self._run_handler(event_handler.name, handler, *args, **kwargs)
else:
# Create the task. Note that this is one task per handler
# function, since users can register several handlers for the
# same event.
task = asyncio.create_task(
self._run_handler(event_handler.name, handler, *args, **kwargs)
)
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
async def _run_task(self, event_name: str, *args, **kwargs):
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
async def _run_handler(self, event_name: str, handler, *args, **kwargs):
"""Execute a single handler for an event.
Args:
event_name: The name of the event being handled.
event_name: The event name for this handler.
handler: The handler function to run.
*args: Positional arguments to pass to handlers.
**kwargs: Keyword arguments to pass to handlers.
"""
try:
for handler in self._event_handlers[event_name]:
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
except Exception as e:
logger.exception(f"Exception in event handler {event_name}: {e}")
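
Review note: the `BaseObject` change means a handler for a `sync=True` event finishes before the caller continues, while handlers for other events are scheduled as tasks and may run later. A minimal sketch of that difference, using hypothetical free functions `run_handler` and `dispatch` rather than the real `BaseObject` methods:

```python
import asyncio
import inspect
from typing import Any, Callable, List


async def run_handler(handler: Callable[..., Any], *args: Any) -> None:
    """Call one handler, awaiting it when it is a coroutine function."""
    if inspect.iscoroutinefunction(handler):
        await handler(*args)
    else:
        handler(*args)


async def dispatch(handlers: List[Callable[..., Any]], is_sync: bool, *args: Any) -> List[asyncio.Task]:
    """Run handlers inline when is_sync is True, else one task per handler."""
    tasks: List[asyncio.Task] = []
    for handler in handlers:
        if is_sync:
            await run_handler(handler, *args)
        else:
            tasks.append(asyncio.create_task(run_handler(handler, *args)))
    return tasks


async def demo() -> List[str]:
    log: List[str] = []

    async def slow(tag: str) -> None:
        await asyncio.sleep(0.01)
        log.append(f"handler:{tag}")

    await dispatch([slow], True, "sync")
    log.append("after sync dispatch")  # handler has already finished
    tasks = await dispatch([slow], False, "task")
    log.append("after task dispatch")  # handler has not run yet
    await asyncio.gather(*tasks)
    return log
```

This ordering is presumably why `on_before_disconnect` is registered with `sync=True`: the handler can complete before the transport actually leaves the room.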

View File

@@ -0,0 +1,67 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from pipecat.frames.frames import (
EndFrame,
Frame,
InterruptionFrame,
TextFrame,
TransportMessageUrgentFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
async def test_interruption_and_wait(self):
class DelayFrameProcessor(FrameProcessor):
"""This processor just gives the event loop time to switch
between tasks. Otherwise things happen too fast."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await asyncio.sleep(0.1)
await self.push_frame(frame, direction)
class InterruptFrameProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_interruption_task_frame_and_wait()
await self.push_frame(TransportMessageUrgentFrame(message=frame.text))
else:
await self.push_frame(frame, direction)
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
frames_to_send = [
# Just a random interruption to make sure we don't clear anything
# before the actual `InterruptionTaskFrame` interruption.
InterruptionFrame(),
# This will generate an `InterruptionTaskFrame` and will wait for an
# `InterruptionFrame`.
TextFrame(text="Hello from Pipecat!"),
# Just give time for everything to complete.
SleepFrame(sleep=0.5),
EndFrame(),
]
expected_down_frames = [
InterruptionFrame,
InterruptionFrame,
TransportMessageUrgentFrame,
EndFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
send_end_frame=False,
)
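
Reviewer note: the test above relies on a "request, then wait for confirmation" handshake: a processor asks for an interruption and only continues once the interruption has actually arrived. The snippet below is a generic asyncio sketch of that handshake under stated assumptions. It does not use pipecat, and `request_and_wait`, `coordinator`, and the string markers are hypothetical names chosen for illustration.

```python
import asyncio


async def request_and_wait(request_queue, interruption_seen, timeout=1.0):
    """Ask for an interruption, then block until it has been observed."""
    await request_queue.put("interruption-request")
    # Fail loudly instead of hanging forever if nothing answers the request.
    await asyncio.wait_for(interruption_seen.wait(), timeout)


async def coordinator(request_queue, interruption_seen, log):
    """Receive the request and later signal that the interruption happened."""
    request = await request_queue.get()
    log.append(request)
    await asyncio.sleep(0.01)  # simulate the interruption travelling back
    log.append("interruption")
    interruption_seen.set()


async def main():
    log = []
    request_queue = asyncio.Queue()
    interruption_seen = asyncio.Event()
    coordinator_task = asyncio.create_task(
        coordinator(request_queue, interruption_seen, log)
    )
    await request_and_wait(request_queue, interruption_seen)
    # Only reached after the interruption was signalled.
    log.append("urgent-message")
    await coordinator_task
    return log
```

Running `asyncio.run(main())` returns the log in request, interruption, message order, which mirrors the ordering the test expects from `expected_down_frames`.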


@@ -0,0 +1,303 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Unit tests for ServiceSwitcher and related components."""
import unittest
from pipecat.frames.frames import (
Frame,
ManuallySwitchServiceFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import run_test
class MockFrameProcessor(FrameProcessor):
"""A test frame processor that tracks which frames it has processed."""
def __init__(self, test_name: str, **kwargs):
"""Initialize the test processor with a name.
Args:
test_name: A unique name for this processor instance.
**kwargs: Additional arguments passed to the parent FrameProcessor.
"""
super().__init__(name=test_name, **kwargs)
self.test_name = test_name
self.processed_frames = []
self.frame_count = 0
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process an incoming frame and track it.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
self.processed_frames.append(frame)
self.frame_count += 1
await self.push_frame(frame, direction)
def reset_counters(self):
"""Reset the frame tracking counters."""
self.processed_frames = []
self.frame_count = 0
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcherStrategyManual."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_services(self):
"""Test initialization with a list of services."""
strategy = ServiceSwitcherStrategyManual(self.services)
self.assertEqual(strategy.services, self.services)
self.assertEqual(strategy.active_service, self.service1) # First service should be active
def test_init_with_empty_services(self):
"""Test initialization with an empty list of services."""
strategy = ServiceSwitcherStrategyManual([])
self.assertEqual(strategy.services, [])
self.assertIsNone(strategy.active_service)
def test_handle_manually_switch_service_frame(self):
"""Test manual service switching with ManuallySwitchServiceFrame."""
strategy = ServiceSwitcherStrategyManual(self.services)
# Initially service1 should be active
self.assertEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
# Switch to service2
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertEqual(strategy.active_service, self.service2)
self.assertNotEqual(strategy.active_service, self.service3)
# Switch to service3
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
self.assertEqual(strategy.active_service, self.service3)
def test_handle_frame_unsupported_frame_type(self):
"""Test that unsupported frame types raise an error."""
strategy = ServiceSwitcherStrategyManual(self.services)
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
with self.assertRaises(ValueError) as context:
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
self.assertIn("Unsupported frame type", str(context.exception))
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcher."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_manual_strategy(self):
"""Test initialization with manual strategy."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.services, self.services)
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.strategy.services, self.services)
async def test_default_active_service(self):
"""Test that the initially-active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send some test frames
frames_to_send = [
TextFrame(text="Hello 1"),
TextFrame(text="Hello 2"),
TextFrame(text="Hello 3"),
]
await run_test(
switcher,
frames_to_send=frames_to_send,
expected_down_frames=[TextFrame, TextFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Only service1 should have processed the text frames
# Note: The service also receives StartFrame and EndFrame, so count those too
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
self.assertEqual(len(text_frames), 3)
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service2_text_frames), 0)
self.assertEqual(len(service3_text_frames), 0)
# Verify the actual text frames processed
for i, frame in enumerate(text_frames):
self.assertEqual(frame.text, f"Hello {i + 1}")
async def test_service_switching(self):
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send a test frame, a switch frame, and another test frame
await run_test(
switcher,
frames_to_send=[
TextFrame("Hello 1"),
ManuallySwitchServiceFrame(service=self.service2),
TextFrame("Hello 2"),
],
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Verify service2 received the frame
service1_text_frames = [
f for f in self.service1.processed_frames if isinstance(f, TextFrame)
]
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service1_text_frames), 1)
self.assertEqual(len(service2_text_frames), 1)
self.assertEqual(len(service3_text_frames), 0)
self.assertEqual(service1_text_frames[0].text, "Hello 1")
self.assertEqual(service2_text_frames[0].text, "Hello 2")
async def test_multi_service_switcher_targeting(self):
"""Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline."""
# Create services for first switcher
switcher1_service1 = MockFrameProcessor("switcher1_service1")
switcher1_service2 = MockFrameProcessor("switcher1_service2")
switcher1_services = [switcher1_service1, switcher1_service2]
# Create services for second switcher
switcher2_service1 = MockFrameProcessor("switcher2_service1")
switcher2_service2 = MockFrameProcessor("switcher2_service2")
switcher2_services = [switcher2_service1, switcher2_service2]
# Create two service switchers
switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)
# Create a pipeline with both switchers: switcher1 -> switcher2
pipeline = Pipeline([switcher1, switcher2])
# Reset counters
for service in switcher1_services + switcher2_services:
service.reset_counters()
# Initially, both switchers should use their first services
self.assertEqual(switcher1.strategy.active_service, switcher1_service1)
self.assertEqual(switcher2.strategy.active_service, switcher2_service1)
# Send frames to test the pipeline:
# 1. Text frame (should go through both switchers' active services)
# 2. Switch frame targeting switcher1's second service
# 3. Text frame (should go through switcher1's new service and switcher2's original service)
# 4. Switch frame targeting switcher2's second service
# 5. Text frame (should go through switcher1's current service and switcher2's new service)
await run_test(
pipeline,
frames_to_send=[
TextFrame("Before any switches"),
ManuallySwitchServiceFrame(service=switcher1_service2), # Switch first switcher
TextFrame("After switching first switcher"),
ManuallySwitchServiceFrame(service=switcher2_service2), # Switch second switcher
TextFrame("After switching second switcher"),
],
expected_down_frames=[
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
],
expected_up_frames=[], # Expect no error frames
)
# Verify the active services changed correctly
self.assertEqual(switcher1.strategy.active_service, switcher1_service2)
self.assertEqual(switcher2.strategy.active_service, switcher2_service2)
# Verify frame distribution:
# First text frame should go through switcher1_service1 and switcher2_service1
switcher1_service1_texts = [
f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame)
]
switcher2_service1_texts = [
f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame)
]
# Second text frame should go through switcher1_service2 and switcher2_service1
switcher1_service2_texts = [
f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame)
]
# Third text frame should go through switcher1_service2 and switcher2_service2
switcher2_service2_texts = [
f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame)
]
# Verify frame counts and content
self.assertEqual(len(switcher1_service1_texts), 1)
self.assertEqual(switcher1_service1_texts[0].text, "Before any switches")
self.assertEqual(len(switcher1_service2_texts), 2)
self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher")
self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher")
self.assertEqual(len(switcher2_service1_texts), 2)
self.assertEqual(switcher2_service1_texts[0].text, "Before any switches")
self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher")
self.assertEqual(len(switcher2_service2_texts), 1)
self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")
if __name__ == "__main__":
unittest.main()
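
Reviewer note: the behaviour these tests pin down for the manual strategy can be summarised in a few lines. The sketch below is inferred from the assertions only and is not the `ServiceSwitcherStrategyManual` implementation. `SwitchFrame` and `ManualStrategySketch` are hypothetical names, the `direction` argument is omitted, and ignoring a switch request for a service the strategy does not own is an assumption drawn from the multi-switcher test.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SwitchFrame:
    """Hypothetical stand-in for ManuallySwitchServiceFrame."""

    service: Any


class ManualStrategySketch:
    """Tracks which of several services is active, switched on request."""

    def __init__(self, services: List[Any]):
        self.services = services
        # The first service starts active; an empty list has no active service.
        self.active_service: Optional[Any] = services[0] if services else None

    def handle_frame(self, frame: Any) -> None:
        if not isinstance(frame, SwitchFrame):
            raise ValueError(f"Unsupported frame type: {type(frame).__name__}")
        # Assumption: a request for a service owned by another switcher is ignored.
        if frame.service in self.services:
            self.active_service = frame.service
```

With two such strategies in a row, a switch frame naming a service from the second one leaves the first one's active service unchanged, which is the property `test_multi_service_switcher_targeting` checks.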

2604
uv.lock generated

File diff suppressed because it is too large