Compare commits

8 Commits

Author | SHA1 | Message | Date
James Hush | 29d4a56663 | Working on the 46 example | 2025-09-17 11:59:16 +08:00
James Hush | 373a09ecd6 | Working on the 46 example | 2025-09-17 11:59:10 +08:00
James Hush | 07f54c48f3 | This is working | 2025-09-17 11:53:07 +08:00
James Hush | c8a3d65aa4 | Save progress | 2025-09-17 11:39:21 +08:00
James Hush | 50a2a0dc86 | ok its kinda working | 2025-09-17 11:29:11 +08:00
James Hush | 0421d97954 | Save changes | 2025-09-17 11:09:03 +08:00
James Hush | 54c8f336c3 | Save progress | 2025-09-16 16:43:38 +08:00
James Hush | b086fbafe6 | feat: Add OpenAI Agents SDK integration service | 2025-09-16 16:20:30 +08:00

Commit message body for b086fbafe6:
- Create new OpenAIAgentService that integrates OpenAI Agents SDK with Pipecat
- Support for agent loops, handoffs, guardrails, and session management
- Add streaming and non-streaming response modes
- Include comprehensive tool integration and error handling
- Add optional dependency for openai-agents package
- Create foundational examples showing basic usage and agent handoffs
- Add comprehensive tests with mocked dependencies
- Include detailed documentation and README

Key features:
- Real-time streaming responses compatible with Pipecat pipelines
- Agent handoffs for specialized task delegation
- Tool calling with automatic schema generation
- Input/output guardrails for safety and validation
- Session context management for conversation continuity
- Built-in tracing and monitoring integration

Examples:
- 45-openai-agent-basic.py: Basic agent with weather and trivia tools
- 46-openai-agent-handoffs.py: Multi-agent system with specialist handoffs
35 changed files with 6245 additions and 5220 deletions

AGENTS.md

@@ -0,0 +1,285 @@
# AGENTS.md
## Project Overview
Pipecat is an open-source Python framework for building real-time voice and multimodal conversational AI agents. The codebase is organized around a pipeline architecture where data flows through connected services (STT → LLM → TTS).
## Development Environment Setup
### Prerequisites
- **Minimum Python Version:** 3.10
- **Recommended Python Version:** 3.12
- **Package Manager:** uv (recommended) or pip
### Setup Commands
```bash
# Clone the repository
git clone https://github.com/pipecat-ai/pipecat.git
cd pipecat
# Install dependencies with uv (recommended)
uv sync --group dev --all-extras \
--no-extra gstreamer \
--no-extra krisp \
--no-extra local \
--no-extra ultravox
# Or with pip
pip install -e ".[dev]"
# Install pre-commit hooks
uv run pre-commit install
# Set up environment variables
cp env.example .env
```
## Build and Test Commands
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest tests/test_name.py
# Run tests with coverage
uv run pytest --cov=pipecat --cov-report=html
```
### Code Quality
```bash
# Format code (required before commits)
uv run ruff format
# Lint code
uv run ruff check
# Type checking
uv run mypy src/pipecat
# Run pre-commit checks manually
uv run pre-commit run --all-files
```
### Documentation
```bash
# Build API documentation
cd docs/api
./build-docs.sh
# Build docs manually
sphinx-build -b html . _build/html -W --keep-going
```
## Code Style Guidelines
### Python Standards
- **Formatting:** Strict PEP 8 via Ruff
- **Docstrings:** Google-style format
- **Type Hints:** Required for all public APIs
- **Import Organization:** Automated via Ruff
### Docstring Conventions
- **Classes:** Describe purpose + `__init__` with complete `Args:` section
- **Dataclasses:** Use `Parameters:` section, no `__init__` docstring
- **Methods:** Include `Args:` and `Returns:` sections
- **Properties:** Must have `Returns:` section
- **Examples:** Use `Examples:` section with `::` syntax
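The conventions above can be sketched together as follows. The names `RetryPolicy` and `Greeter` are hypothetical and not part of Pipecat; only the docstring layout is the point, and this is one reading of the rules rather than a canonical template.

```python
from dataclasses import dataclass


@dataclass
class RetryPolicy:
    """Retry settings for an outbound request.

    Parameters:
        max_attempts: Maximum number of attempts before giving up.
        backoff_secs: Delay between attempts, in seconds.
    """

    max_attempts: int = 3
    backoff_secs: float = 0.5


class Greeter:
    """Builds greeting strings for a named user.

    Examples:
        Basic usage::

            greeter = Greeter("Ada")
            greeter.greet()
    """

    def __init__(self, name: str, punctuation: str = "!"):
        """Initialize the greeter.

        Args:
            name: Name of the user to greet.
            punctuation: Character appended to every greeting.
        """
        self._name = name
        self._punctuation = punctuation

    @property
    def name(self) -> str:
        """Get the name of the user being greeted.

        Returns:
            The user name passed at construction time.
        """
        return self._name

    def greet(self, salutation: str = "Hello") -> str:
        """Build a greeting.

        Args:
            salutation: Word placed before the name.

        Returns:
            The formatted greeting string.
        """
        return f"{salutation}, {self._name}{self._punctuation}"
```

Note that the dataclass documents its fields under `Parameters:` and has no `__init__` docstring, while the regular class documents constructor arguments under `Args:` in `__init__`.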
### File Organization
```
src/pipecat/            # Main package
├── processors/         # Frame processors
├── services/           # AI service integrations
├── transports/         # Communication layers
├── frames/             # Data frame definitions
└── pipeline/           # Pipeline orchestration
examples/foundational/  # Step-by-step tutorials
tests/                  # Test suite
```
## Testing Instructions
### Test Structure
- **Unit Tests:** Test individual components in isolation
- **Integration Tests:** Test service interactions
- **Example Tests:** Validate foundational examples work
### Adding Tests
```bash
# Test naming convention
test_<component>_<functionality>.py
# Run specific test pattern
uv run pytest -k "test_pipeline"
# Run with debugging
uv run pytest -s -vv tests/test_name.py::test_function
```
### Pre-commit Requirements
All commits must pass:
- Ruff formatting
- Ruff linting
- Type checking
- Basic test suite
## Dependency Management
### Using uv (Recommended)
```bash
# Add runtime dependency
uv add package-name
# Add optional dependency
uv add --optional service package-name
# Add development dependency
uv add --group dev package-name
# Update lockfile
uv lock
# Sync dependencies
uv sync
```
### Important Notes
- **Always commit both `pyproject.toml` and `uv.lock` together**
- **Never manually edit `uv.lock`** - it's auto-generated
- **Use extras for optional service dependencies** (e.g., `[openai]`, `[cartesia]`)
## Project Structure Guidelines
### Service Integration
When adding new AI services:
1. Create service class in `src/pipecat/services/<provider>/`
2. Follow existing patterns (e.g., STTService, LLMService)
3. Add to appropriate extras in `pyproject.toml`
4. Include tests in `tests/`
5. Add documentation examples
### Frame Processing
For custom processors:
1. Inherit from `FrameProcessor`
2. Implement `process_frame()` method. ALWAYS explicitly call `await super().process_frame(frame, direction)` at the top of this method.
3. Handle frame direction (FrameDirection.UPSTREAM/DOWNSTREAM)
4. Add proper type hints and docstrings
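The steps above can be sketched as follows. To keep the sketch runnable without Pipecat installed, `FrameProcessor`, `FrameDirection`, and `TextFrame` are simplified stand-ins defined locally (an assumption of this example); in real code they would be imported from Pipecat, and the real base class does considerably more bookkeeping. `UppercaseProcessor` is a hypothetical name.

```python
import asyncio
from dataclasses import dataclass
from enum import Enum


class FrameDirection(Enum):
    """Stand-in for Pipecat's FrameDirection (assumption for this sketch)."""

    DOWNSTREAM = 1
    UPSTREAM = 2


@dataclass
class TextFrame:
    """Stand-in for a text-carrying frame."""

    text: str


class FrameProcessor:
    """Stand-in base class that only records what was pushed."""

    def __init__(self) -> None:
        self.pushed: list[tuple[object, FrameDirection]] = []

    async def process_frame(self, frame: object, direction: FrameDirection) -> None:
        # The real base class handles system frames here.
        pass

    async def push_frame(
        self, frame: object, direction: FrameDirection = FrameDirection.DOWNSTREAM
    ) -> None:
        self.pushed.append((frame, direction))


class UppercaseProcessor(FrameProcessor):
    """Uppercases downstream text frames and forwards everything else as-is."""

    async def process_frame(self, frame: object, direction: FrameDirection) -> None:
        # Step 2: let the base class see the frame before doing anything else.
        await super().process_frame(frame, direction)

        # Step 3: only transform frames travelling downstream.
        if isinstance(frame, TextFrame) and direction is FrameDirection.DOWNSTREAM:
            frame = TextFrame(frame.text.upper())

        # Forward the frame so the rest of the pipeline still receives it.
        await self.push_frame(frame, direction)


async def demo() -> list[tuple[object, FrameDirection]]:
    processor = UppercaseProcessor()
    await processor.process_frame(TextFrame("hello"), FrameDirection.DOWNSTREAM)
    await processor.process_frame(TextFrame("hello"), FrameDirection.UPSTREAM)
    return processor.pushed
```

Forwarding unhandled frames matters as much as the `super()` call: a processor that swallows frames it does not recognize stalls everything after it in the pipeline.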
### Transport Implementation
For new transport layers:
1. Inherit from `BaseTransport`
2. Implement required abstract methods
3. Handle connection lifecycle
4. Support both input and output streams
## Security Considerations
### API Keys
- **Never commit API keys** to the repository
- **Use environment variables** for all secrets
- **Reference `env.example`** for required variables
- **Use `.env` files** for local development
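One way to apply these rules is to fail fast when a secret is missing instead of passing an empty string to a service. The helper below is a hypothetical sketch, not a Pipecat API; `require_env` and `MissingSecretError` are names invented for illustration.

```python
import os


class MissingSecretError(RuntimeError):
    """Raised when a required environment variable is unset or empty."""


def require_env(name: str) -> str:
    """Return the value of an environment variable or fail loudly.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The stripped, non-empty value of the variable.
    """
    value = os.environ.get(name, "").strip()
    if not value:
        # The message names the variable but never echoes a secret value.
        raise MissingSecretError(f"Set {name} in your environment or .env file")
    return value
```

A call such as `require_env("OPENAI_API_KEY")` at startup would surface a misconfigured `.env` immediately rather than as an authentication error mid-call.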
### Input Validation
- **Validate all external inputs** (audio, text, API responses)
- **Sanitize user data** before processing
- **Handle rate limiting** for external services
- **Implement proper timeout handling**
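As an illustration of the timeout point, a call to an external service can be bounded with `asyncio.wait_for`. This is a minimal sketch under stated assumptions: `call_with_timeout` and `slow_service` are hypothetical names, and returning `None` on timeout is just one possible policy.

```python
import asyncio
from typing import Awaitable, Optional, TypeVar

T = TypeVar("T")


async def call_with_timeout(call: Awaitable[T], timeout_secs: float) -> Optional[T]:
    """Await `call`, returning None if it does not finish in time."""
    try:
        return await asyncio.wait_for(call, timeout=timeout_secs)
    except asyncio.TimeoutError:
        # wait_for cancels the pending call before raising.
        return None


async def slow_service(delay_secs: float) -> str:
    """Hypothetical stand-in for an external API request."""
    await asyncio.sleep(delay_secs)
    return "ok"
```

In a voice pipeline the caller would typically log the timeout and fall back to a spoken apology or a retry rather than silently dropping the turn.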
## Performance Guidelines
### Memory Management
- **Clean up resources** in transport disconnection handlers
- **Use async context managers** for service connections
- **Implement proper frame lifecycle** management
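The async context manager guideline can be sketched with `contextlib.asynccontextmanager`. `FakeConnection` and `service_connection` are hypothetical stand-ins for a real service client, used only so the example runs on its own.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class FakeConnection:
    """Hypothetical stand-in for a service connection."""

    def __init__(self) -> None:
        self.open = False

    async def connect(self) -> None:
        self.open = True

    async def close(self) -> None:
        self.open = False


@asynccontextmanager
async def service_connection() -> AsyncIterator[FakeConnection]:
    """Open a connection and guarantee it is closed on exit."""
    conn = FakeConnection()
    await conn.connect()
    try:
        yield conn
    finally:
        # Runs on normal exit and when the body raises.
        await conn.close()


async def demo() -> tuple[bool, bool]:
    async with service_connection() as conn:
        was_open = conn.open
    return was_open, conn.open
```

Keeping the cleanup in a `finally` block means a failing frame handler cannot leave the connection open.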
### Latency Optimization
- **Choose appropriate STT services** for latency requirements
- **Use streaming TTS** when possible
- **Implement connection pooling** for HTTP services
- **Consider WebRTC** for real-time applications
## Common Patterns
### Error Handling
```python
@transport.event_handler("on_error")
async def on_error(transport, error):
    logger.error(f"Transport error: {error}")
    # Shutdown the pipeline
    await task.queue_frame(EndFrame())
```
### Service Configuration
```python
# Use environment variables for configuration
service = OpenAILLMService(
    api_key=os.getenv("OPENAI_API_KEY", ""),
    model="gpt-4o",
    params={"temperature": 0.7}
)
```
### Pipeline Assembly
```python
pipeline = Pipeline([
    transport.input(),
    stt_service,
    context_aggregator.user(),
    llm_service,
    tts_service,
    transport.output(),
    context_aggregator.assistant(),
])
```
## Commit and PR Guidelines
### Commit Message Format
```
<type>(<scope>): <description>
[optional body]
[optional footer]
```
Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
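A commit header in this format could be checked with a small script such as the sketch below. It is not part of the repository's tooling; `is_valid_header` is a hypothetical helper, and treating the `(<scope>)` part as optional is an assumption made here.

```python
import re

COMMIT_TYPES = ("feat", "fix", "docs", "style", "refactor", "test", "chore")

# <type>(<scope>): <description>; the scope is treated as optional (assumption).
HEADER_RE = re.compile(
    r"^(?P<type>" + "|".join(COMMIT_TYPES) + r")"
    r"(?:\((?P<scope>[^()\s]+)\))?"
    r": (?P<description>\S.*)$"
)


def is_valid_header(message: str) -> bool:
    """Check the first line of a commit message against the format."""
    lines = message.splitlines()
    return bool(lines) and HEADER_RE.match(lines[0]) is not None
```

Under this check, `feat(services): add agent service` passes, while a message such as `Save progress` does not.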
### PR Requirements
- **All tests must pass**
- **Code must be properly formatted** (Ruff)
- **Include appropriate tests** for new functionality
- **Update documentation** if needed
- **Reference related issues** in description
### Review Process
1. Automated checks must pass
2. Manual code review by maintainers
3. Documentation review for user-facing changes
4. Integration testing for service additions
## Troubleshooting
### Common Issues
- **Import errors:** Run `uv sync` to ensure dependencies are installed
- **Test failures:** Check environment variables in `.env`
- **Format errors:** Run `uv run ruff format` before committing
- **Type errors:** Ensure all public methods have type hints
### Development Tips
- **Use foundational examples** as starting points for testing
- **Check existing services** for integration patterns
- **Run tests frequently** during development
- **Use IDE integration** for Ruff formatting
### Getting Help
- **Documentation:** [docs.pipecat.ai](https://docs.pipecat.ai)
- **Issues:** [GitHub Issues](https://github.com/pipecat-ai/pipecat/issues)


@@ -9,23 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `on_before_disconnect` synchronous event to `DailyTransport` and
`LiveKitTransport`.
- It is now possible to register synchronous event handlers. By default, all
event handlers are executed in a separate task. However, in some cases we want
to guarantee order of execution, for example, executing something before
disconnecting a transport.
```python
self._register_event_handler("on_event_name", sync=True)
```
- Added support for global location in `GoogleVertexLLMService`. The service now
supports both regional locations (e.g., "us-east4") and the "global" location
for Vertex AI endpoints. When using "global" location, the service will use
`aiplatform.googleapis.com` as the API host instead of the regional format.
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
fired when the pipeline is done running. This can be the result of a
`StopFrame`, `CancelFrame` or `EndFrame`.
@@ -36,64 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
...
```
### Changed
- Updated Silero VAD model to v6.
- Updated `livekit` to 1.0.13.
- `torch` and `torchaudio` are no longer required for running Smart Turn
locally. This avoids gigabytes of dependencies being installed.
- Updated `websockets` dependency to support version 15.0. Removed deprecated
usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
`AWSTranscribeSTTService` for compatibility.
- Refactored `pyproject.toml` to reduce websockets dependency repetition using
self-referencing extras. All websockets-dependent services now reference a
shared `websockets-base` extra.
### Deprecated
- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
longer needed to determine which transcription or translation frames to
emit.
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
instead.
### Fixed
- Fixed an issue where multiple handlers for an event would not run in parallel.
- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
ID from the `on_dialin_connected` event, when not explicitly provided. Now
supports cold transfers (from incoming dial-in calls) by automatically
tracking session IDs from connection events.
- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
the code never consumes these frames, they are queued in memory, causing a
memory leak.
- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
pushed.
- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
not wait if a previous interruption had already happened.
- Fixed a couple of bugs in `ServiceSwitcher`:
- Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
- `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
an effect too early, essentially "jumping the queue" in terms of pipeline
frame ordering.
- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
`False` from an idle callback. The task now terminates naturally instead of
attempting to cancel itself.
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
when a bot speaks and user input is blocked.


@@ -21,8 +21,6 @@
🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.
🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
## 🧠 Why Pipecat?
- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling


@@ -11,7 +11,7 @@ import sys
 from dotenv import load_dotenv
 from loguru import logger
-from pipecat.frames.frames import TTSSpeakFrame
+from pipecat.frames.frames import TextFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineTask
@@ -50,7 +50,7 @@ async def main():
     async def on_first_participant_joined(transport, participant_id):
         await asyncio.sleep(1)
         await task.queue_frame(
-            TTSSpeakFrame(
+            TextFrame(
                 "Hello there! How are you doing today? Would you like to talk about the weather?"
             )
         )


@@ -30,6 +30,10 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
 load_dotenv(override=True)
+# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
+# to the Smart Turn v3 ONNX model file.
+smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
 # We store functions so objects (e.g. SileroVADAnalyzer) don't get
 # instantiated. The function will be called when the desired transport gets
 # selected.
@@ -38,19 +42,25 @@ transport_params = {
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
     "twilio": lambda: FastAPIWebsocketParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
     "webrtc": lambda: TransportParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
 }


@@ -0,0 +1,205 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""
Basic OpenAI Agent service example.

This example demonstrates how to use the OpenAI Agents SDK within a Pipecat
pipeline to create an interactive agent with tool calling capabilities.

Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""

import os
import random
from typing import Any, List

# Import agents SDK for tools and agent creation
from agents import Agent, function_tool
from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam

from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# Transport configuration
transport_params = {
    "daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
    "twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
    "webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}


@function_tool
def get_weather(location: str) -> str:
    """Get the current weather for a location.

    Args:
        location: The location to get weather for

    Returns:
        A weather description string
    """
    # Mock weather data - in real usage, integrate with weather API
    weather_data = {
        "San Francisco": "Foggy, 65°F",
        "New York": "Sunny, 72°F",
        "London": "Rainy, 59°F",
        "Tokyo": "Partly cloudy, 68°F",
    }
    return weather_data.get(location, f"Weather data not available for {location}")


@function_tool
def get_random_fact() -> str:
    """Get a random interesting fact.

    Returns:
        A random fact string
    """
    facts = [
        "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
        "Octopuses have three hearts and blue blood.",
        "The Great Wall of China isn't visible from space with the naked eye.",
        "Bananas are berries, but strawberries aren't.",
    ]
    return random.choice(facts)


def get_random_fact_tool():
    """Example tool function for random facts."""

    def get_random_fact() -> str:
        """Get a random interesting fact.

        Returns:
            A random fact string.
        """
        facts = [
            "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
            "A group of flamingos is called a 'flamboyance'.",
            "Octopuses have three hearts and blue blood.",
            "The Great Wall of China isn't visible from space with the naked eye.",
            "Bananas are berries, but strawberries aren't.",
        ]
        return random.choice(facts)

    return get_random_fact


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    logger.info("Starting OpenAI Agent bot")

    # Set up STT for speech recognition
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY", ""),
        model="nova-2",
    )

    # Set up TTS for voice output
    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY", ""),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # Create tools for the agent
    tools: list[Any] = [
        get_weather,
        get_random_fact,
    ]

    # Create the agent with tools
    agent = Agent(
        name="Assistant",
        instructions="""You are a helpful assistant with access to weather information and random facts.
        You can:
        - Check weather for any location using the get_weather tool
        - Share interesting facts using the get_random_fact tool
        - Have natural conversations
        Be friendly, informative, and engaging in your responses.""",
        tools=tools,
    )

    # Initialize the OpenAI Agent service with the pre-configured agent
    agent_service = OpenAIAgentService(
        agent=agent,
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Set up conversation context with initial system message
    messages: List[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": "You are a helpful assistant with access to weather information and random facts. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]
    context = OpenAILLMContext(messages)
    context_aggregator = agent_service.create_context_aggregator(context)

    # Create the processing pipeline with context aggregators
    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,  # Speech to text
            context_aggregator.user(),  # User responses
            agent_service,  # OpenAI Agent processing
            tts,  # Text to speech
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    # Send an initial greeting when client connects
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected, sending greeting")
        # Kick off the conversation by adding system message and running LLM
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
    await runner.run(task)


async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud."""
    transport = await create_transport(runner_args, transport_params)
    await run_bot(transport, runner_args)


if __name__ == "__main__":
    from pipecat.runner.run import main

    main()


@@ -0,0 +1,276 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""
Advanced OpenAI Agent service example with handoffs.

This example demonstrates how to use multiple agents with handoffs in the
OpenAI Agents SDK within a Pipecat pipeline, showcasing agent orchestration
and specialization.

Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""

import os
import random
from typing import Any, Dict, List

from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam

from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# Transport configuration
transport_params = {
    "daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
    "twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
    "webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}


def create_weather_tools():
    """Create weather-related tools."""

    def get_weather(location: str) -> str:
        """Get current weather for a location."""
        conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"]
        temp = random.randint(-10, 35)
        condition = random.choice(conditions)
        return f"The weather in {location} is {condition} with a temperature of {temp}°C."

    def get_forecast(location: str, days: int = 3) -> str:
        """Get weather forecast for multiple days."""
        forecast = []
        for i in range(days):
            conditions = ["sunny", "cloudy", "rainy", "snowy"]
            temp = random.randint(-5, 30)
            condition = random.choice(conditions)
            day = "today" if i == 0 else f"in {i} day{'s' if i > 1 else ''}"
            forecast.append(f"{day.capitalize()}: {condition}, {temp}°C")
        return f"Weather forecast for {location}:\n" + "\n".join(forecast)

    return [get_weather, get_forecast]


def create_trivia_tools():
    """Create trivia and fact tools."""

    def get_random_fact() -> str:
        """Get a random interesting fact."""
        facts = [
            "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
            "A group of flamingos is called a 'flamboyance'.",
            "Octopuses have three hearts and blue blood.",
            "The Great Wall of China isn't visible from space with the naked eye.",
            "Bananas are berries, but strawberries aren't.",
            "Wombat poop is cube-shaped.",
            "A shrimp's heart is in its head.",
            "It's impossible to hum while holding your nose.",
        ]
        return random.choice(facts)

    def get_science_fact() -> str:
        """Get a random science fact."""
        facts = [
            "The speed of light in a vacuum is approximately 299,792,458 meters per second.",
            "DNA stands for Deoxyribonucleic Acid.",
            "The human brain uses about 20% of the body's total energy.",
            "There are more possible games of chess than atoms in the observable universe.",
            "A single bolt of lightning contains enough energy to toast 100,000 slices of bread.",
        ]
        return random.choice(facts)

    return [get_random_fact, get_science_fact]


def create_math_tools():
    """Create math calculation tools."""

    def calculate(expression: str) -> str:
        """Safely calculate a mathematical expression."""
        try:
            # Only allow basic math operations for safety
            allowed_chars = set("0123456789+-*/.() ")
            if not all(c in allowed_chars for c in expression):
                return "Sorry, I can only calculate basic math expressions with +, -, *, /, and parentheses."
            result = eval(expression)
            return f"{expression} = {result}"
        except Exception as e:
            return f"Error calculating '{expression}': {str(e)}"

    def generate_math_problem() -> str:
        """Generate a random math problem."""
        operations = ["+", "-", "*"]
        a = random.randint(1, 20)
        b = random.randint(1, 20)
        op = random.choice(operations)
        if op == "+":
            answer = a + b
        elif op == "-":
            answer = a - b
        else:  # multiplication
            answer = a * b
        return f"Here's a math problem for you: {a} {op} {b} = ?"

    return [calculate, generate_math_problem]


async def create_specialist_agents():
    """Create specialized agents for different domains."""
    # Weather specialist agent
    weather_agent = OpenAIAgentService(
        name="Weather Specialist",
        instructions="""You are a weather specialist. You provide detailed weather information,
        forecasts, and weather-related advice. Use your tools to get accurate weather data.
        Be informative and helpful about weather conditions and what they might mean for
        outdoor activities.""",
        tools=create_weather_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Trivia specialist agent
    trivia_agent = OpenAIAgentService(
        name="Trivia Master",
        instructions="""You are a trivia and facts specialist. You love sharing interesting
        facts, trivia, and educational content. Use your tools to provide fascinating
        information and engage users with fun facts. Make learning enjoyable!""",
        tools=create_trivia_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Math specialist agent
    math_agent = OpenAIAgentService(
        name="Math Helper",
        instructions="""You are a mathematics specialist. You help with calculations,
        math problems, and mathematical concepts. Use your tools to solve problems
        and generate practice questions. Make math accessible and fun!""",
        tools=create_math_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    return weather_agent, trivia_agent, math_agent


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    logger.info("Starting OpenAI Agent bot with handoffs")

    # Set up STT for speech recognition
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY", ""),
        model="nova-2",
    )

    # Set up TTS for voice output
    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY", ""),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # Create specialist agents
    weather_agent, trivia_agent, math_agent = await create_specialist_agents()

    # Create the main triage agent that can hand off to specialists
    triage_agent = OpenAIAgentService(
        name="Assistant Coordinator",
        instructions="""You are a helpful assistant coordinator. Your role is to understand
        what the user needs and direct them to the right specialist:
        - For weather questions, forecasts, or outdoor activity planning -> Weather Specialist
        - For interesting facts, trivia, or educational content -> Trivia Master
        - For calculations, math problems, or mathematical help -> Math Helper
        If the request doesn't clearly fit a specialist, you can handle general conversation
        yourself. Always be friendly and explain when you're connecting them to a specialist.""",
        handoffs=[weather_agent.agent, trivia_agent.agent, math_agent.agent],  # type: ignore
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Set up conversation context with initial system message
    messages: List[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": "You are a helpful assistant coordinator with access to weather information, trivia, and math tools. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]
    context = OpenAILLMContext(messages)
    context_aggregator = triage_agent.create_context_aggregator(context)

    # Create the processing pipeline with context aggregators
    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,  # Speech to text
            context_aggregator.user(),  # User responses
            triage_agent,  # OpenAI Agent processing
            tts,  # Text to speech
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    # Send an initial greeting when client connects
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected, sending greeting")
        # Kick off the conversation by adding system message and running LLM
        messages.append(
            {
                "role": "system",
                "content": "Please introduce yourself to the user as an AI assistant coordinator who works with specialists for weather, trivia, and math topics.",
            }
        )
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
    await runner.run(task)


async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud."""
    transport = await create_transport(runner_args, transport_params)
    await run_bot(transport, runner_args)


if __name__ == "__main__":
    from pipecat.runner.run import main

    main()


@@ -34,7 +34,7 @@ dependencies = [
     "pyloudnorm~=0.1.1",
     "resampy~=0.4.3",
     "soxr~=0.5.0",
-    "openai>=1.74.0,<=1.99.1",
+    "openai>=1.74.0,<2.0.0",
     # Pinning numba to resolve package dependencies
     "numba==0.61.2",
     "wait_for2>=0.4.1; python_version<'3.12'",
@@ -47,68 +47,68 @@ Website = "https://pipecat.ai"
 [project.optional-dependencies]
 aic = [ "aic-sdk~=1.0.1" ]
 anthropic = [ "anthropic~=0.49.0" ]
-assemblyai = [ "pipecat-ai[websockets-base]" ]
-asyncai = [ "pipecat-ai[websockets-base]" ]
-aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
+assemblyai = [ "websockets>=13.1,<15.0" ]
+asyncai = [ "websockets>=13.1,<15.0" ]
+aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
 aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
 azure = [ "azure-cognitiveservices-speech~=1.42.0"]
-cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
+cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
 cerebras = []
 deepseek = []
 daily = [ "daily-python~=0.19.9" ]
 deepgram = [ "deepgram-sdk~=4.7.0" ]
-elevenlabs = [ "pipecat-ai[websockets-base]" ]
+elevenlabs = [ "websockets>=13.1,<15.0" ]
 fal = [ "fal-client~=0.5.9" ]
 fireworks = []
-fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
-gladia = [ "pipecat-ai[websockets-base]" ]
-google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
+fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
+gladia = [ "websockets>=13.1,<15.0" ]
+google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
 grok = []
 groq = [ "groq~=0.23.0" ]
 gstreamer = [ "pygobject~=3.50.0" ]
-heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
+heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
 inworld = []
 krisp = [ "pipecat-ai-krisp~=0.4.0" ]
 koala = [ "pvkoala~=2.0.3" ]
 langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
-livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
-lmnt = [ "pipecat-ai[websockets-base]" ]
+livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
+lmnt = [ "websockets>=13.1,<15.0" ]
 local = [ "pyaudio~=0.2.14" ]
-mcp = [ "mcp[cli]~=1.9.4" ]
+mcp = [ "mcp[cli]>=1.11.0,<2.0.0" ]
 mem0 = [ "mem0ai~=0.1.94" ]
 mistral = []
 mlx-whisper = [ "mlx-whisper~=0.4.2" ]
 moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
 nim = []
-neuphonic = [ "pipecat-ai[websockets-base]" ]
+neuphonic = [ "websockets>=13.1,<15.0" ]
 noisereduce = [ "noisereduce~=3.0.3" ]
-openai = [ "pipecat-ai[websockets-base]" ]
-openpipe = [ "openpipe~=4.50.0" ]
+openai = [ "websockets>=13.1,<15.0" ]
+openai-agent = [ "openai-agents~=0.3.0" ]
+# openpipe = [ "openpipe~=4.50.0" ] # Temporarily disabled due to openai version conflict
 openrouter = []
 perplexity = []
-playht = [ "pipecat-ai[websockets-base]" ]
+playht = [ "websockets>=13.1,<15.0" ]
 qwen = []
-rime = [ "pipecat-ai[websockets-base]" ]
+rime = [ "websockets>=13.1,<15.0" ]
 riva = [ "nvidia-riva-client~=2.21.1" ]
 runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
 sambanova = []
-sarvam = [ "pipecat-ai[websockets-base]" ]
+sarvam = [ "websockets>=13.1,<15.0" ]
 sentry = [ "sentry-sdk~=2.23.1" ]
 local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
-local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
+local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime>=1.20.1, <2" ]
 remote-smart-turn = []
 silero = [ "onnxruntime>=1.20.1, <2" ]
 simli = [ "simli-ai~=0.1.10"]
-soniox = [ "pipecat-ai[websockets-base]" ]
soniox = [ "websockets>=13.1,<15.0" ]
soundfile = [ "soundfile~=0.13.0" ]
speechmatics = [ "speechmatics-rt>=0.4.0" ]
tavus=[]
together = []
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
websockets-base = [ "websockets>=13.1,<16.0" ]
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
whisper = [ "faster-whisper~=1.1.1" ]
[dependency-groups]

View File

@@ -1,12 +0,0 @@
#!/bin/bash
PID=$1
while true; do
# Clear the screen
clear
# Print the header + RSS in GB
ps -p "$PID" -o pid,comm,rss | \
awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
sleep 1
done

View File

@@ -98,15 +98,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
inputs = self._feature_extractor(
audio_array,
sampling_rate=16000,
return_tensors="np",
return_tensors="pt",
padding="max_length",
max_length=8 * 16000,
truncation=True,
do_normalize=True,
)
# Extract features and ensure correct shape for ONNX
input_features = inputs.input_features.squeeze(0).astype(np.float32)
# Convert to numpy and ensure correct shape for ONNX
input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32)
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
# Run ONNX inference

View File

@@ -1604,7 +1604,7 @@ class MixerEnableFrame(MixerControlFrame):
@dataclass
class ServiceSwitcherFrame(ControlFrame):
"""A base class for frames that affect ServiceSwitcher behavior."""
"""A base class for frames that control ServiceSwitcher behavior."""
pass

View File

@@ -6,15 +6,9 @@
"""Service switcher for switching between different services at runtime, with different switching strategies."""
from dataclasses import dataclass
from typing import Any, Generic, List, Optional, Type, TypeVar
from pipecat.frames.frames import (
ControlFrame,
Frame,
ManuallySwitchServiceFrame,
ServiceSwitcherFrame,
)
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -28,6 +22,19 @@ class ServiceSwitcherStrategy:
self.services = services
self.active_service: Optional[FrameProcessor] = None
def is_active(self, service: FrameProcessor) -> bool:
"""Determine if the given service is the currently active one.
This method should be overridden by subclasses to implement specific logic.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
raise NotImplementedError("Subclasses must implement this method.")
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -53,6 +60,17 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
super().__init__(services)
self.active_service = services[0] if services else None
def is_active(self, service: FrameProcessor) -> bool:
"""Check if the given service is the currently active one.
Args:
service: The service to check.
Returns:
True if the given service is the active one, False otherwise.
"""
return service == self.active_service
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
"""Handle a frame that controls service switching.
@@ -61,21 +79,20 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
direction: The direction of the frame (upstream or downstream).
"""
if isinstance(frame, ManuallySwitchServiceFrame):
self._set_active_if_available(frame.service)
self._set_active(frame.service)
else:
raise ValueError(f"Unsupported frame type: {type(frame)}")
def _set_active_if_available(self, service: FrameProcessor):
"""Set the active service to the given one, if it is in the list of available services.
If it's not in the list, the request is ignored, as it may have been
intended for another ServiceSwitcher in the pipeline.
def _set_active(self, service: FrameProcessor):
"""Set the active service to the given one.
Args:
service: The service to set as active.
"""
if service in self.services:
self.active_service = service
else:
raise ValueError(f"Service {service} is not in the list of available services.")
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
@@ -91,43 +108,6 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
self.services = services
self.strategy = strategy
class ServiceSwitcherFilter(FunctionFilter):
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
def __init__(
self,
wrapped_service: FrameProcessor,
active_service: FrameProcessor,
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction."""
async def filter(_: Frame) -> bool:
return self._wrapped_service == self._active_service
super().__init__(filter, direction)
self._wrapped_service = wrapped_service
self._active_service = active_service
async def process_frame(self, frame, direction):
"""Process a frame through the filter, handling special internal filter-updating frames."""
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
self._active_service = frame.active_service
# Two ServiceSwitcherFilters "sandwich" a service. Push the
# frame only to update the other side of the sandwich, but
# otherwise don't let it leave the sandwich.
if direction == self._direction:
await self.push_frame(frame, direction)
return
await super().process_frame(frame, direction)
@dataclass
class ServiceSwitcherFilterFrame(ControlFrame):
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
active_service: FrameProcessor
@staticmethod
def _make_pipeline_definitions(
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
@@ -141,18 +121,14 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
def _make_pipeline_definition(
service: FrameProcessor, strategy: ServiceSwitcherStrategy
) -> Any:
async def filter(frame) -> bool:
_ = frame
return strategy.is_active(service)
return [
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.DOWNSTREAM,
),
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
service,
ServiceSwitcher.ServiceSwitcherFilter(
wrapped_service=service,
active_service=strategy.active_service,
direction=FrameDirection.UPSTREAM,
),
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
]
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -166,7 +142,3 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
if isinstance(frame, ServiceSwitcherFrame):
self.strategy.handle_frame(frame, direction)
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
active_service=self.strategy.active_service
)
await super().process_frame(service_switcher_filter_frame, direction)

View File

@@ -220,11 +220,6 @@ class FrameProcessor(BaseObject):
self.__process_event: Optional[asyncio.Event] = None
self.__process_frame_task: Optional[asyncio.Task] = None
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
# Then we wait for the corresponding `InterruptionFrame` to travel from
# the start of the pipeline back to the processor that sent the
# `InterruptionTaskFrame`. This wait is handled using the following
# event.
self._wait_for_interruption = False
self._wait_interruption_event = asyncio.Event()
@@ -568,17 +563,11 @@ class FrameProcessor(BaseObject):
"""Pause processing of queued frames."""
logger.trace(f"{self}: pausing frame processing")
self.__should_block_frames = True
# We should also unset the process event here, in case it was set immediately after an interruption
if self.__process_event:
self.__process_event.clear()
async def pause_processing_system_frames(self):
"""Pause processing of queued system frames."""
logger.trace(f"{self}: pausing system frame processing")
self.__should_block_system_frames = True
# We should also unset the input event here, in case it was set immediately after an interruption
if self.__input_event:
self.__input_event.clear()
async def resume_processing_frames(self):
"""Resume processing of queued frames."""
@@ -643,9 +632,7 @@ class FrameProcessor(BaseObject):
await self.__internal_push_frame(frame, direction)
# If we are waiting for an interruption and we get an interruption, then
# we can unblock `push_interruption_task_frame_and_wait()`.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
if isinstance(frame, InterruptionFrame):
self._wait_interruption_event.set()
async def push_interruption_task_frame_and_wait(self):

View File

@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
StartFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -184,13 +185,15 @@ class UserIdleProcessor(FrameProcessor):
Runs in a loop until cancelled or callback indicates completion.
"""
running = True
while running:
while True:
try:
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
except asyncio.TimeoutError:
if not self._interrupted:
self._retry_count += 1
running = await self._callback(self, self._retry_count)
should_continue = await self._callback(self, self._retry_count)
if not should_continue:
await self._stop()
break
finally:
self._idle_event.clear()

View File

@@ -70,6 +70,7 @@ import asyncio
import os
import sys
from contextlib import asynccontextmanager
from typing import Dict
from loguru import logger
@@ -182,14 +183,13 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.transports.smallwebrtc.request_handler import (
SmallWebRTCRequest,
SmallWebRTCRequestHandler,
)
except ImportError as e:
logger.error(f"WebRTC transport dependencies not installed: {e}")
return
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
# Mount the frontend
app.mount("/client", SmallWebRTCPrebuiltUI)
@@ -198,33 +198,51 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
"""Redirect root requests to client interface."""
return RedirectResponse(url="/client/")
# Initialize the SmallWebRTC request handler
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
esp32_mode=esp32_mode, host=host
)
@app.post("/api/offer")
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
async def offer(request: dict, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests and manage peer connections."""
pc_id = request.get("pc_id")
if pc_id and pc_id in pcs_map:
pipecat_connection = pcs_map[pc_id]
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request["sdp"],
type=request["type"],
restart_pc=request.get("restart_pc", False),
)
else:
pipecat_connection = SmallWebRTCConnection()
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
"""Handle WebRTC connection closure and cleanup."""
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
pcs_map.pop(webrtc_connection.pc_id, None)
# Prepare runner arguments with the callback to run your bot
async def webrtc_connection_callback(connection):
bot_module = _get_bot_module()
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
background_tasks.add_task(bot_module.bot, runner_args)
# Delegate handling to SmallWebRTCRequestHandler
answer = await small_webrtc_handler.handle_web_request(
request=request,
webrtc_connection_callback=webrtc_connection_callback,
)
answer = pipecat_connection.get_answer()
# Apply ESP32 SDP munging if enabled
if esp32_mode and host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
pcs_map[answer["pc_id"]] = pipecat_connection
return answer
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage FastAPI application lifecycle and cleanup connections."""
yield
await small_webrtc_handler.close()
coros = [pc.disconnect() for pc in pcs_map.values()]
await asyncio.gather(*coros)
pcs_map.clear()
app.router.lifespan_context = lifespan

View File

@@ -119,6 +119,7 @@ class AsyncAITTSService(InterruptibleTTSService):
"""
super().__init__(
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
push_stop_frames=True,
sample_rate=sample_rate,

View File

@@ -532,7 +532,9 @@ class AWSTranscribeSTTService(STTService):
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")

View File

@@ -13,7 +13,6 @@ supporting multiple languages, custom vocabulary, and various audio processing o
import asyncio
import base64
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Literal, Optional
import aiohttp
@@ -174,6 +173,8 @@ class _InputParamsDescriptor:
"""Descriptor for backward compatibility with deprecation warning."""
def __get__(self, obj, objtype=None):
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -207,7 +208,7 @@ class GladiaSTTService(STTService):
api_key: str,
region: Literal["us-west", "eu-west"] | None = None,
url: str = "https://api.gladia.io/v2/live",
confidence: Optional[float] = None,
confidence: float = 0.5,
sample_rate: Optional[int] = None,
model: str = "solaria-1",
params: Optional[GladiaInputParams] = None,
@@ -223,11 +224,6 @@ class GladiaSTTService(STTService):
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
.. deprecated:: 0.0.86
The 'confidence' parameter is deprecated and will be removed in a future version.
No confidence threshold is applied.
sample_rate: Audio sample rate in Hz. If None, uses service default.
model: Model to use for transcription. Defaults to "solaria-1".
params: Additional configuration parameters for Gladia service.
@@ -240,6 +236,7 @@ class GladiaSTTService(STTService):
params = params or GladiaInputParams()
# Warn about deprecated language parameter if it's used
if params.language is not None:
with warnings.catch_warnings():
warnings.simplefilter("always")
@@ -250,20 +247,11 @@ class GladiaSTTService(STTService):
stacklevel=2,
)
if confidence:
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"The 'confidence' parameter is deprecated and will be removed in a future version. "
"No confidence threshold is applied.",
DeprecationWarning,
stacklevel=2,
)
self._api_key = api_key
self._region = region
self._url = url
self.set_model_name(model)
self._confidence = confidence
self._params = params
self._websocket = None
self._receive_task = None
@@ -587,40 +575,43 @@ class GladiaSTTService(STTService):
elif content["type"] == "transcript":
utterance = content["data"]["utterance"]
confidence = utterance.get("confidence", 0)
language = utterance["language"]
transcript = utterance["text"]
is_final = content["data"]["is_final"]
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
if confidence >= self._confidence:
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
)
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
await self._handle_transcription(
transcript=transcript,
is_final=is_final,
language=language,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=content,
)
)
)
elif content["type"] == "translation":
translated_utterance = content["data"]["translated_utterance"]
original_language = content["data"]["original_language"]
translated_language = translated_utterance["language"]
confidence = translated_utterance.get("confidence", 0)
translation = translated_utterance["text"]
if translated_language != original_language:
if translated_language != original_language and confidence >= self._confidence:
await self.push_frame(
TranslationFrame(
translation, "", time_now_iso8601(), translated_language

View File

@@ -83,23 +83,14 @@ class GoogleVertexLLMService(OpenAILLMService):
self._api_key = self._get_api_token(credentials, credentials_path)
super().__init__(
api_key=self._api_key,
base_url=base_url,
model=model,
params=params,
**kwargs,
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
)
@staticmethod
def _get_base_url(params: InputParams) -> str:
"""Construct the base URL for Vertex AI API."""
# Determine the correct API host based on location
if params.location == "global":
api_host = "aiplatform.googleapis.com"
else:
api_host = f"{params.location}-aiplatform.googleapis.com"
return (
f"https://{api_host}/v1/"
f"https://{params.location}-aiplatform.googleapis.com/v1/"
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
)
@@ -127,14 +118,12 @@ class GoogleVertexLLMService(OpenAILLMService):
if credentials:
# Parse and load credentials from JSON string
creds = service_account.Credentials.from_service_account_info(
json.loads(credentials),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
elif credentials_path:
# Load credentials from JSON file
creds = service_account.Credentials.from_service_account_file(
credentials_path,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
else:
try:

View File

@@ -7,7 +7,7 @@
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
import json
from typing import Any, Dict, List, TypeAlias
from typing import Any, Dict, List, Tuple
from loguru import logger
@@ -28,8 +28,6 @@ except ModuleNotFoundError as e:
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
raise Exception(f"Missing module: {e}")
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
class MCPClient(BaseObject):
"""Client for Model Context Protocol (MCP) servers.
@@ -44,7 +42,7 @@ class MCPClient(BaseObject):
def __init__(
self,
server_params: ServerParameters,
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
**kwargs,
):
"""Initialize the MCP client with server parameters.

View File

@@ -0,0 +1,209 @@
# OpenAI Agents SDK Integration
This service integrates the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) with Pipecat, enabling powerful agentic workflows with features like:
- **Agent loops** with tool calling and response streaming
- **Handoffs** between specialized agents
- **Guardrails** for input/output validation
- **Sessions** with automatic conversation history
- **Built-in tracing** and monitoring
## Installation
Install the OpenAI Agents SDK dependency:
```bash
pip install "pipecat-ai[openai-agent]"
# or
uv add "pipecat-ai[openai-agent]"
```
## Basic Usage
```python
import os
from pipecat.pipeline.pipeline import Pipeline
from pipecat.services.openai_agent import OpenAIAgentService
# Create a simple agent
agent_service = OpenAIAgentService(
    name="Assistant",
    instructions="You are a helpful assistant.",
    api_key=os.getenv("OPENAI_API_KEY"),
    streaming=True,
)
# Use in a pipeline (transport, stt, and tts are created elsewhere)
pipeline = Pipeline([
    transport.input(),
    stt,
    agent_service,
    tts,
    transport.output(),
])
```
## Features
### Tool Integration
```python
def get_weather(location: str) -> str:
    """Get weather for a location."""
    return f"Weather in {location}: sunny, 22°C"
agent_service = OpenAIAgentService(
    name="Weather Assistant",
    instructions="Help users with weather information.",
    tools=[get_weather],
    api_key=os.getenv("OPENAI_API_KEY"),
)
```
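The Agents SDK derives a tool's schema from the function's signature and docstring. The sketch below illustrates that idea using only the standard library. It is not the SDK's implementation: `describe_tool` and the `_JSON_TYPES` mapping are hypothetical names introduced here, and only a handful of primitive types are handled.

```python
import inspect
from typing import Any, Callable, Dict

# Hypothetical mapping from Python annotations to JSON Schema type names.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def describe_tool(func: Callable[..., Any]) -> Dict[str, Any]:
    """Build a JSON-Schema-like description from a function signature."""
    signature = inspect.signature(func)
    properties: Dict[str, Any] = {}
    required = []
    for name, param in signature.parameters.items():
        # Unannotated or unknown types fall back to "string" in this sketch.
        properties[name] = {"type": _JSON_TYPES.get(param.annotation, "string")}
        # Parameters without a default value are treated as required.
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func) or "",
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


def get_weather(location: str) -> str:
    """Get weather for a location."""
    return f"Weather in {location}: sunny, 22°C"


schema = describe_tool(get_weather)
```

Type hints and a docstring on each tool function give the model a usable parameter description, which is why the examples here annotate every argument.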
### Agent Handoffs
```python
# Create specialized agents
weather_agent = OpenAIAgentService(
    name="Weather Specialist",
    instructions="Provide weather information and forecasts.",
    tools=[get_weather, get_forecast],
)
trivia_agent = OpenAIAgentService(
    name="Trivia Master",
    instructions="Share interesting facts and trivia.",
    tools=[get_random_fact],
)
# Create coordinator that can hand off to specialists
coordinator = OpenAIAgentService(
    name="Coordinator",
    instructions="Route users to the right specialist.",
    handoffs=[weather_agent.agent, trivia_agent.agent],
)
```
### Guardrails
```python
from agents import InputGuardrail, GuardrailFunctionOutput
async def content_filter(ctx, agent, input_data):
    # Check input for appropriate content.
    # is_inappropriate() is a placeholder for your own check.
    if is_inappropriate(input_data):
        return GuardrailFunctionOutput(
            tripwire_triggered=True,
            output_info="Content not allowed",
        )
    # output_info is a required field, so pass it explicitly
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info=None)
agent_service = OpenAIAgentService(
    name="Safe Assistant",
    instructions="You are a helpful and safe assistant.",
    input_guardrails=[InputGuardrail(guardrail_function=content_filter)],
)
```
### Session Management
```python
agent_service = OpenAIAgentService(
    name="Personal Assistant",
    instructions="Remember user preferences and context.",
    session_config={
        "user_id": "user_123",
        "memory_enabled": True,
    },
)
# Update session context dynamically
agent_service.update_session_context({
    "user_preferences": {"language": "en", "style": "formal"}
})
```
## Configuration Options
### Basic Parameters
- `name`: Agent identifier for handoffs and tracing
- `instructions`: System prompt defining agent behavior
- `api_key`: OpenAI API key (or use `OPENAI_API_KEY` env var)
- `streaming`: Enable real-time token streaming (default: True)
### Advanced Configuration
- `tools`: List of callable functions for the agent to use
- `handoffs`: List of other agents this agent can transfer to
- `input_guardrails`: Input validation and filtering
- `output_guardrails`: Output validation and filtering
- `model_config`: Model settings (model, temperature, etc.)
- `session_config`: Session and memory configuration
### Model Configuration
```python
agent_service = OpenAIAgentService(
    name="Precise Assistant",
    instructions="Provide accurate, concise responses.",
    model_config={
        "model": "gpt-4o",
        "temperature": 0.1,
        "max_tokens": 150,
    },
)
```
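In the service constructor added by this change, the only `model_config` key read is `model`, and it falls back to `gpt-4o` when the key or the whole dictionary is missing. The sketch below mirrors that fallback. `resolve_model` and `DEFAULT_MODEL` are hypothetical names used for illustration; the constructor inlines the same logic.

```python
from typing import Any, Dict, Optional

# Default used when no model is configured (matches the constructor's fallback).
DEFAULT_MODEL = "gpt-4o"


def resolve_model(model_config: Optional[Dict[str, Any]]) -> str:
    """Return the configured model name, or the default if none is given."""
    if model_config:
        return model_config.get("model", DEFAULT_MODEL)
    return DEFAULT_MODEL


# A config without a "model" key still resolves to the default.
chosen = resolve_model({"temperature": 0.1})
```

If other settings such as `temperature` matter for your use case, it may be worth checking how the service forwards them before relying on them.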
## Examples
See the foundational examples:
- [`45-openai-agent-basic.py`](../examples/foundational/45-openai-agent-basic.py) - Basic agent with tools
- [`46-openai-agent-handoffs.py`](../examples/foundational/46-openai-agent-handoffs.py) - Multi-agent system with handoffs
## Methods
### Core Methods
- `update_agent_config()` - Update instructions and model settings
- `add_tool()` - Add new tools dynamically
- `add_handoff_agent()` - Add handoff destinations
- `get_session_context()` - Get current session state
- `update_session_context()` - Update session variables
### Lifecycle Methods
Inherited from `AIService`:
- `start()` - Initialize the agent
- `stop()` - Clean up resources
- `cancel()` - Cancel ongoing operations
## Integration with Pipecat
The service processes `TextFrame` inputs and generates:
- `LLMFullResponseStartFrame` - Response beginning
- `LLMTextFrame` - Streaming text tokens (if streaming enabled)
- `LLMFullResponseEndFrame` - Response completion
These are the same frame types that Pipecat's LLM services emit, so downstream TTS services and context aggregators can consume the output unchanged.
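The frame ordering described above can be sketched without Pipecat installed. The three frame classes below are stand-ins defined only for this illustration (the real ones live in `pipecat.frames.frames`), and `fake_token_stream` and `frames_for_response` are hypothetical helpers, not part of the service.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List


# Stand-in frame types for illustration only.
@dataclass
class LLMFullResponseStartFrame:
    pass


@dataclass
class LLMTextFrame:
    text: str


@dataclass
class LLMFullResponseEndFrame:
    pass


async def fake_token_stream() -> AsyncIterator[str]:
    """Pretend to be a streaming agent that yields text tokens."""
    for token in ["Hello", ", ", "world"]:
        yield token


async def frames_for_response(tokens: AsyncIterator[str]) -> List[object]:
    """Wrap streamed tokens in start/end frames, one text frame per token."""
    frames: List[object] = [LLMFullResponseStartFrame()]
    async for token in tokens:
        frames.append(LLMTextFrame(text=token))
    frames.append(LLMFullResponseEndFrame())
    return frames


frames = asyncio.run(frames_for_response(fake_token_stream()))
names = [type(frame).__name__ for frame in frames]
```

The start and end frames bracket every response, which lets a downstream aggregator know where one assistant turn begins and ends even when the text arrives token by token.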
## Error Handling
The service includes robust error handling for:
- Missing API keys or SDK installation
- Agent processing failures
- Network connectivity issues
- Malformed tool responses
Errors are emitted as `ErrorFrame` objects in the pipeline.
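A minimal sketch of the error path, assuming the service catches exceptions from the agent run and converts them into an error frame rather than letting them propagate. `ErrorFrame` here is a stand-in for the class in `pipecat.frames.frames`, and `run_or_error` and `failing_agent` are hypothetical names used only for this example.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Union


@dataclass
class ErrorFrame:
    """Stand-in for pipecat.frames.frames.ErrorFrame (illustration only)."""

    error: str


async def run_or_error(
    run_agent: Callable[[str], Awaitable[str]], text: str
) -> Union[str, ErrorFrame]:
    """Run the agent; turn any failure into an ErrorFrame instead of raising."""
    try:
        return await run_agent(text)
    except Exception as exc:
        return ErrorFrame(error=f"Agent processing failed: {exc}")


async def failing_agent(text: str) -> str:
    """Simulate a network failure during the agent run."""
    raise RuntimeError("network unreachable")


result = asyncio.run(run_or_error(failing_agent, "hello"))
```

Reporting failures as frames keeps the pipeline running, so an application can decide whether to retry, speak a fallback message, or shut down.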
## Requirements
- OpenAI API key
- `openai-agents` package
- Python 3.10+
## Limitations
- Currently supports OpenAI models only (via Agents SDK)
- Handoffs work within individual requests (no cross-request state)
- Real-time voice features require additional setup

View File

@@ -0,0 +1,11 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Agents SDK service for Pipecat integration."""
from .agent_service import OpenAIAgentService
__all__ = ["OpenAIAgentService"]

View File

@@ -0,0 +1,567 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Agents SDK integration service.
Provides integration with the OpenAI Agents SDK for building AI applications
within Pipecat pipelines. This service allows leveraging agent loops, handoffs,
guardrails, sessions, and tools from the OpenAI Agents SDK.
"""
import asyncio
import os
from dataclasses import dataclass
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Protocol,
Sequence,
Union,
override,
runtime_checkable,
)
from loguru import logger
try:
from agents import Agent, InputGuardrail, OutputGuardrail, Runner, Tool
from agents.result import RunResult, RunResultStreaming
from agents.stream_events import StreamEvent
except ImportError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use OpenAI Agents SDK, you need to `pip install openai-agents`. "
"Also, set `OPENAI_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TextFrame,
UserImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
@runtime_checkable
class ToolLike(Protocol):
"""Protocol for tool-like objects."""
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Tool call interface."""
...
@runtime_checkable
class AgentLike(Protocol):
"""Protocol for agent-like objects."""
name: str
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Agent call interface."""
...
@dataclass
class OpenAIAgentContextAggregatorPair:
"""Pair of OpenAI Agent context aggregators for user and assistant messages.
Parameters:
_user: User context aggregator for processing user messages.
_assistant: Assistant context aggregator for processing assistant messages.
"""
_user: "OpenAIAgentUserContextAggregator"
_assistant: "OpenAIAgentAssistantContextAggregator"
def user(self) -> "OpenAIAgentUserContextAggregator":
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> "OpenAIAgentAssistantContextAggregator":
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class OpenAIAgentService(AIService):
"""OpenAI Agents SDK service for Pipecat.
Integrates the OpenAI Agents SDK with Pipecat's pipeline architecture,
enabling advanced agentic workflows with features like handoffs, guardrails,
sessions, and tools within real-time conversational AI applications.
The service processes text input frames and generates streaming responses
using the agent's configured capabilities.
"""
def __init__(
self,
*,
agent: Optional[Agent] = None,
name: str = "Assistant",
instructions: Union[str, Sequence[str]] = "You are a helpful assistant.",
handoffs: Optional[Sequence[AgentLike]] = None,
tools: Optional[Sequence[ToolLike]] = None,
input_guardrails: Optional[Sequence[InputGuardrail]] = None,
output_guardrails: Optional[Sequence[OutputGuardrail]] = None,
model_config: Optional[Dict[str, Any]] = None,
session_config: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
streaming: bool = True,
**kwargs,
):
"""Initialize the OpenAI Agent service.
Args:
agent: Pre-configured Agent instance. If provided, other agent configuration
parameters will be ignored.
name: Name of the agent for identification and handoffs.
instructions: System instructions that define the agent's behavior.
handoffs: List of other agents this agent can hand off to.
tools: List of callable functions the agent can use as tools.
input_guardrails: List of input validation guardrails.
output_guardrails: List of output validation guardrails.
model_config: Configuration for the underlying language model.
session_config: Configuration for session management.
api_key: OpenAI API key. If not provided, will use OPENAI_API_KEY env var.
streaming: Whether to use streaming responses for real-time output.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(**kwargs)
# Set up API key
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
elif not os.getenv("OPENAI_API_KEY"):
logger.warning("No OpenAI API key provided. Set OPENAI_API_KEY environment variable.")
# Create or use existing agent
if agent:
self._agent = agent
else:
# Convert sequences to lists and handle string instructions
agent_handoffs: List[Any] = list(handoffs) if handoffs else []
agent_tools: List[Any] = list(tools) if tools else []
agent_input_guardrails: List[Any] = list(input_guardrails) if input_guardrails else []
agent_output_guardrails: List[Any] = (
list(output_guardrails) if output_guardrails else []
)
# Handle instructions - convert sequence to string if needed
if isinstance(instructions, str):
agent_instructions = instructions
else:
agent_instructions = " ".join(str(instr) for instr in instructions)
self._agent = Agent(
name=name,
instructions=agent_instructions,
handoffs=agent_handoffs,
tools=agent_tools,
input_guardrails=agent_input_guardrails,
output_guardrails=agent_output_guardrails,
model=model_config.get("model", "gpt-4o") if model_config else "gpt-4o",
)
self._streaming = streaming
self._session_config = session_config or {}
self._current_session = None
self._accumulated_text = ""
# Set model name for metrics
if model_config and "model" in model_config:
self.set_model_name(model_config["model"])
else:
self.set_model_name("gpt-4o") # Default model
logger.info(f"Initialized OpenAI Agent service: {self._agent.name}")
@property
def agent(self) -> Agent:
"""Get the underlying OpenAI Agent.
Returns:
The configured Agent instance.
"""
return self._agent
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIAgentContextAggregatorPair:
"""Create OpenAI-specific context aggregators for agent interactions.
Creates a pair of context aggregators optimized for OpenAI Agent interactions,
including support for function calls, tool usage, and conversation management.
Args:
context: The LLM context to create aggregators for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
OpenAIAgentContextAggregatorPair: A pair of context aggregators, one for
the user and one for the assistant, encapsulated in an
OpenAIAgentContextAggregatorPair.
"""
user = OpenAIAgentUserContextAggregator(context, params=user_params)
assistant = OpenAIAgentAssistantContextAggregator(context, params=assistant_params)
return OpenAIAgentContextAggregatorPair(_user=user, _assistant=assistant)
def update_agent_config(
self,
*,
instructions: Optional[str] = None,
model_config: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
"""Update agent configuration dynamically.
Args:
instructions: New system instructions for the agent.
model_config: Updated model configuration.
**kwargs: Additional agent configuration parameters.
"""
if instructions:
self._agent.instructions = instructions
logger.info(f"Updated agent instructions for {self._agent.name}")
if model_config:
# Note: OpenAI Agents SDK handles model configuration during agent creation
# We can't update model_config after agent is created, but we can update our model name
if "model" in model_config:
self.set_model_name(model_config["model"])
logger.info(f"Updated model config for {self._agent.name}")
async def start(self, frame: StartFrame):
"""Start the OpenAI Agent service.
Logs the startup and delegates to the parent AIService; no agent session is created here.
Args:
frame: The start frame containing initialization parameters.
"""
logger.info(f"Starting OpenAI Agent service: {self._agent.name}")
await super().start(frame)
async def stop(self, frame: EndFrame):
"""Stop the OpenAI Agent service.
Logs the shutdown and delegates cleanup to the parent AIService.
Args:
frame: The end frame.
"""
logger.info(f"Stopping OpenAI Agent service: {self._agent.name}")
await super().stop(frame)
async def cancel(self, frame: CancelFrame):
"""Cancel the OpenAI Agent service.
Cancels any ongoing operations.
Args:
frame: The cancel frame.
"""
logger.info(f"Cancelling OpenAI Agent service: {self._agent.name}")
await super().cancel(frame)
@override
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process frames and handle agent interactions.
Processes OpenAILLMContextFrame and TextFrame by running them through the OpenAI Agent
and streams the results back as LLM frames.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
# Process context frame through the agent
try:
await self.push_frame(LLMFullResponseStartFrame())
# Extract the latest user message from the context
messages = frame.context.get_messages()
if messages:
# Get the last user message
for message in reversed(messages):
if message.get("role") == "user":
content = message.get("content", "")
if isinstance(content, list):
# Extract text from content array
text_parts = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
text_parts.append(part.get("text", ""))
user_input = " ".join(text_parts)
else:
user_input = str(content)
if user_input.strip():
await self._process_agent_request(user_input)
break
await self.push_frame(LLMFullResponseEndFrame())
except Exception as e:
logger.error(f"Error processing agent context: {e}")
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
elif isinstance(frame, TextFrame):
# Process text input through the agent directly (for backwards compatibility)
try:
await self.push_frame(LLMFullResponseStartFrame())
await self._process_agent_request(frame.text)
await self.push_frame(LLMFullResponseEndFrame())
except Exception as e:
logger.error(f"Error processing agent request: {e}")
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
else:
# For frames we don't handle, pass them through with direction
await self.push_frame(frame, direction)
async def _process_agent_request(self, input_text: str):
"""Process an agent request and stream the results.
Args:
input_text: The user input text to process.
"""
logger.debug(f"Processing agent request: {input_text}")
if self._streaming:
await self._process_streaming_response(input_text)
else:
await self._process_non_streaming_response(input_text)
async def _process_streaming_response(self, input_text: str):
"""Process a streaming agent response.
Args:
input_text: The user input text to process.
"""
try:
# Run the agent with streaming
result: RunResultStreaming = Runner.run_streamed(
self._agent, input_text, context=self._session_config
)
has_streaming_deltas = False
# Process the stream events
async for event in result.stream_events():
if event.type == "raw_response_event":
# Handle token-by-token streaming
# Text deltas carry a non-empty `delta` attribute; other raw events are skipped
delta_text = getattr(event.data, "delta", None)
if delta_text:
has_streaming_deltas = True
self._accumulated_text += delta_text
await self.push_frame(LLMTextFrame(text=delta_text))
elif event.type == "run_item_stream_event":
# Handle completed items
if event.item.type == "message_output_item":
# Only process complete message if we didn't get streaming deltas
if not has_streaming_deltas:
message_text = self._extract_message_text(event.item)
logger.debug(
f"Processing complete message (no deltas): {message_text[:50]}..."
if len(message_text) > 50
else f"Processing complete message: {message_text}"
)
if message_text:
await self.push_frame(LLMTextFrame(text=message_text))
elif event.item.type == "tool_call_item":
# Use getattr for safe attribute access
tool_name = getattr(event.item, "tool_name", "unknown")
logger.debug(f"Tool called: {tool_name}")
elif event.item.type == "tool_call_output_item":
output = getattr(event.item, "output", "no output")
logger.debug(f"Tool output: {output}")
elif event.type == "agent_updated_stream_event":
logger.debug(f"Agent updated: {event.new_agent.name}")
# Reset accumulated text for next request
self._accumulated_text = ""
except Exception as e:
logger.error(f"Error in streaming response: {e}")
raise
async def _process_non_streaming_response(self, input_text: str):
"""Process a non-streaming agent response.
Args:
input_text: The user input text to process.
"""
try:
# Run the agent without streaming
result: RunResult = await Runner.run(
self._agent, input_text, context=self._session_config
)
# Send the final output
if result.final_output:
await self.push_frame(LLMTextFrame(text=result.final_output))
except Exception as e:
logger.error(f"Error in non-streaming response: {e}")
raise
def _extract_message_text(self, item) -> str:
"""Extract text from a message output item.
Args:
item: The message output item from the agent.
Returns:
The extracted text content.
"""
try:
# Handle OpenAI Agents SDK MessageOutputItem format
if hasattr(item, "raw_item") and hasattr(item.raw_item, "content"):
content = item.raw_item.content
if isinstance(content, list):
text_parts = []
for content_part in content:
if hasattr(content_part, "text"):
text_parts.append(content_part.text)
elif (
isinstance(content_part, dict)
and content_part.get("type") == "output_text"
):
text_parts.append(content_part.get("text", ""))
elif isinstance(content_part, dict) and content_part.get("type") == "text":
text_parts.append(content_part.get("text", ""))
return "".join(text_parts)
elif isinstance(content, str):
return content
# Handle direct content attribute
elif hasattr(item, "content"):
if isinstance(item.content, str):
return item.content
elif isinstance(item.content, list):
# Extract text from content array
text_parts = []
for content_part in item.content:
if isinstance(content_part, dict) and content_part.get("type") == "text":
text_parts.append(content_part.get("text", ""))
elif isinstance(content_part, str):
text_parts.append(content_part)
return "".join(text_parts)
# If no text content found, return empty string instead of str(item)
logger.debug(f"No extractable text content found in item: {type(item)}")
return ""
except Exception as e:
logger.warning(f"Could not extract text from message item: {e}")
return ""
async def add_tool(self, tool_function: ToolLike):
"""Add a tool function to the agent.
Args:
tool_function: A callable function or Tool object to add as a tool.
"""
if hasattr(self._agent, "tools"):
# Annotate as List[Any] to sidestep the type variance issue
tools_list: List[Any] = self._agent.tools
tools_list.append(tool_function)
tool_name = getattr(
tool_function, "__name__", getattr(tool_function, "name", "unknown")
)
logger.info(f"Added tool {tool_name} to agent {self._agent.name}")
async def add_handoff_agent(self, agent: AgentLike):
"""Add a handoff agent.
Args:
agent: Another Agent instance or handoff object that this agent can hand off to.
"""
if hasattr(self._agent, "handoffs"):
# Annotate as List[Any] to sidestep the type variance issue
handoffs_list: List[Any] = self._agent.handoffs
handoffs_list.append(agent)
agent_name = getattr(agent, "name", "unknown")
logger.info(f"Added handoff agent {agent_name} to agent {self._agent.name}")
def get_session_context(self) -> Dict[str, Any]:
"""Get the current session context.
Returns:
Dictionary containing the current session context.
"""
return self._session_config.copy()
def update_session_context(self, context: Dict[str, Any]):
"""Update the session context.
Args:
context: Dictionary of context updates to apply.
"""
self._session_config.update(context)
logger.debug(f"Updated session context for agent {self._agent.name}")
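
The context-frame branch of `process_frame` above walks the message list backwards and flattens multi-part content into one string. A minimal stdlib-only sketch of that lookup follows; the helper name `latest_user_text` is hypothetical, and because the flattened diff hides the indentation of the `break`, the sketch assumes the search stops at the most recent user message even when that message is blank.

```python
from typing import Any, Dict, List, Optional


def latest_user_text(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Return the text of the most recent user message, or None if unusable."""
    for message in reversed(messages):
        if message.get("role") != "user":
            continue
        content = message.get("content", "")
        if isinstance(content, list):
            # Multi-part content: keep only the parts typed as text.
            parts = [
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            text = " ".join(parts)
        else:
            text = str(content)
        # Assumption: only the latest user message is considered.
        return text if text.strip() else None
    return None
```

Returning `None` for a blank message mirrors the `user_input.strip()` guard, so the agent is never run on empty input.
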
class OpenAIAgentUserContextAggregator(LLMUserContextAggregator):
"""OpenAI Agent-specific user context aggregator.
Handles aggregation of user messages for OpenAI Agent services.
Inherits all functionality from the base LLMUserContextAggregator.
"""
pass
class OpenAIAgentAssistantContextAggregator(LLMAssistantContextAggregator):
"""OpenAI Agent-specific assistant context aggregator.
Handles aggregation of assistant messages for OpenAI Agent services.
Currently inherits all functionality from the base LLMAssistantContextAggregator;
no agent-specific handling of function calls or tool usage is implemented here.
"""
pass
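
The streaming path in `_process_streaming_response` pushes token deltas as they arrive and falls back to the completed message only when no delta was seen, which avoids sending the same text twice. The sketch below isolates that rule with stand-in types; `FakeEvent` and `texts_to_emit` are hypothetical names and are not part of the OpenAI Agents SDK or Pipecat.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional


@dataclass
class FakeEvent:
    """Hypothetical stand-in for an SDK stream event (not the real SDK type)."""

    type: str
    delta: Optional[str] = None  # set on raw response events that carry text
    message_text: Optional[str] = None  # set on completed message items


def texts_to_emit(events: Iterable[FakeEvent]) -> Iterator[str]:
    """Yield the text chunks that would be pushed downstream, in order."""
    saw_delta = False
    for event in events:
        if event.type == "raw_response_event" and event.delta:
            saw_delta = True
            yield event.delta
        elif event.type == "run_item_stream_event" and event.message_text:
            # Emit the complete message only when nothing was streamed,
            # otherwise the same text would be delivered twice.
            if not saw_delta:
                yield event.message_text
```

As in the service, the flag is kept for the whole run, so a completed message that arrives after any delta is dropped.
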

View File

@@ -25,7 +25,6 @@ from pydantic import BaseModel
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
from pipecat.frames.frames import (
CancelFrame,
ControlFrame,
EndFrame,
ErrorFrame,
Frame,
@@ -42,7 +41,6 @@ from pipecat.frames.frames import (
UserAudioRawFrame,
UserImageRawFrame,
UserImageRequestFrame,
DataFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
from pipecat.transcriptions.language import Language
@@ -107,17 +105,6 @@ class DailyInputTransportMessageUrgentFrame(InputTransportMessageUrgentFrame):
participant_id: Optional[str] = None
@dataclass
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
"""Frame to update remote participants in Daily calls.
Parameters:
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
"""
remote_participants: Mapping[str, Any] = None
class WebRTCVADAnalyzer(VADAnalyzer):
"""Voice Activity Detection analyzer using WebRTC.
@@ -228,7 +215,6 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Called when the active speaker of the call has changed.
on_joined: Called when bot successfully joined a room.
on_left: Called when bot left a room.
on_before_leave: Called when bot is about to leave the room.
on_error: Called when an error occurs.
on_app_message: Called when receiving an app message.
on_call_state_updated: Called when call state changes.
@@ -258,7 +244,6 @@ class DailyCallbacks(BaseModel):
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
on_left: Callable[[], Awaitable[None]]
on_before_leave: Callable[[], Awaitable[None]]
on_error: Callable[[str], Awaitable[None]]
on_app_message: Callable[[Any, str], Awaitable[None]]
on_call_state_updated: Callable[[str], Awaitable[None]]
@@ -374,7 +359,6 @@ class DailyTransportClient(EventHandler):
self._transcription_ids = []
self._transcription_status = None
self._dial_out_session_id: str = ""
self._dial_in_session_id: str = ""
self._joining = False
self._joined = False
@@ -735,9 +719,6 @@ class DailyTransportClient(EventHandler):
logger.info(f"Leaving {self._room_url}")
# Call callback before leaving.
await self._callbacks.on_before_leave()
if self._params.transcription_enabled:
await self.stop_transcription()
@@ -842,16 +823,6 @@ class DailyTransportClient(EventHandler):
Args:
settings: SIP call transfer settings.
"""
session_id = (
settings.get("sessionId") or self._dial_out_session_id or self._dial_in_session_id
)
if not session_id:
logger.error("Unable to transfer SIP call: 'sessionId' is not set")
return
# Update 'sessionId' field.
settings["sessionId"] = session_id
future = self._get_event_loop().create_future()
self._client.sip_call_transfer(settings, completion=completion_callback(future))
await future
@@ -1170,7 +1141,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in connection data.
"""
self._dial_in_session_id = data["sessionId"] if "sessionId" in data else ""
self._call_event_callback(self._callbacks.on_dialin_connected, data)
def on_dialin_ready(self, sip_endpoint: str):
@@ -1187,9 +1157,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in stop data.
"""
# Cleanup only if our session stopped.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_stopped, data)
def on_dialin_error(self, data: Any):
@@ -1198,9 +1165,6 @@ class DailyTransportClient(EventHandler):
Args:
data: Dial-in error data.
"""
# Cleanup only if our session errored out.
if data.get("sessionId") == self._dial_in_session_id:
self._dial_in_session_id = ""
self._call_event_callback(self._callbacks.on_dialin_error, data)
def on_dialin_warning(self, data: Any):
@@ -1235,7 +1199,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out stop data.
"""
# Cleanup only if our session stopped.
if data.get("sessionId") == self._dial_out_session_id:
if data["sessionId"] == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_stopped, data)
@@ -1246,7 +1210,7 @@ class DailyTransportClient(EventHandler):
data: Dial-out error data.
"""
# Cleanup only if our session errored out.
if data.get("sessionId") == self._dial_out_session_id:
if data["sessionId"] == self._dial_out_session_id:
self._dial_out_session_id = ""
self._call_event_callback(self._callbacks.on_dialout_error, data)
@@ -1803,31 +1767,6 @@ class DailyOutputTransport(BaseOutputTransport):
# Leave the room.
await self._client.leave()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process outgoing frames, including transport messages.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
logger.debug(f"Got a DailyUpdateRemoteParticipantsFrame: {frame}")
await self._client.update_remote_participants(frame.remote_participants)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process outgoing frames, including transport messages.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
await self._client.update_remote_participants(frame.remote_participants)
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
"""Send a transport message to participants.
@@ -1923,7 +1862,6 @@ class DailyTransport(BaseTransport):
on_active_speaker_changed=self._on_active_speaker_changed,
on_joined=self._on_joined,
on_left=self._on_left,
on_before_leave=self._on_before_leave,
on_error=self._on_error,
on_app_message=self._on_app_message,
on_call_state_updated=self._on_call_state_updated,
@@ -1987,10 +1925,6 @@ class DailyTransport(BaseTransport):
self._register_event_handler("on_recording_started")
self._register_event_handler("on_recording_stopped")
self._register_event_handler("on_recording_error")
self._register_event_handler("on_before_disconnect", sync=True)
# Deprecated
self._register_event_handler("on_joined")
self._register_event_handler("on_left")
#
# BaseTransport
@@ -2242,10 +2176,6 @@ class DailyTransport(BaseTransport):
"""Handle room left events."""
await self._call_event_handler("on_left")
async def _on_before_leave(self):
"""Handle before leave room events."""
await self._call_event_handler("on_before_disconnect")
async def _on_error(self, error):
"""Handle error events and push error frames."""
await self._call_event_handler("on_error", error)
@@ -2385,7 +2315,7 @@ class DailyTransport(BaseTransport):
"""Handle participant updated events."""
await self._call_event_handler("on_participant_updated", participant)
async def _on_transcription_message(self, message: Mapping[str, Any]) -> None:
async def _on_transcription_message(self, message: Dict[str, Any]) -> None:
"""Handle transcription message events."""
await self._call_event_handler("on_transcription_message", message)
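
Two of the hunks above differ in how the SIP session id is handled: one side resolves it from the settings, then the dial-out id, then the dial-in id, and compares with `data.get("sessionId")`; the other indexes `data["sessionId"]` directly, which raises `KeyError` when the key is absent. A minimal sketch of the tolerant variant, with hypothetical helper names:

```python
from typing import Any, Dict, Optional


def resolve_session_id(
    settings: Dict[str, Any], dial_out_id: str = "", dial_in_id: str = ""
) -> Optional[str]:
    """Pick the session id: explicit setting first, then dial-out, then dial-in."""
    return settings.get("sessionId") or dial_out_id or dial_in_id or None


def is_our_session(data: Dict[str, Any], current_id: str) -> bool:
    """Tolerant comparison: a payload without 'sessionId' simply does not match."""
    return data.get("sessionId") == current_id
```

With the indexing form, an event payload that lacks `sessionId` would raise inside the callback instead of being ignored.
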

View File

@@ -114,7 +114,6 @@ class LiveKitCallbacks(BaseModel):
on_connected: Callable[[], Awaitable[None]]
on_disconnected: Callable[[], Awaitable[None]]
on_before_disconnect: Callable[[], Awaitable[None]]
on_participant_connected: Callable[[str], Awaitable[None]]
on_participant_disconnected: Callable[[str], Awaitable[None]]
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
@@ -283,7 +282,6 @@ class LiveKitTransportClient:
return
logger.info(f"Disconnecting from {self._room_name}")
await self._callbacks.on_before_disconnect()
await self.room.disconnect()
self._connected = False
logger.info(f"Disconnected from {self._room_name}")
@@ -920,7 +918,6 @@ class LiveKitTransport(BaseTransport):
callbacks = LiveKitCallbacks(
on_connected=self._on_connected,
on_disconnected=self._on_disconnected,
on_before_disconnect=self._on_before_disconnect,
on_participant_connected=self._on_participant_connected,
on_participant_disconnected=self._on_participant_disconnected,
on_audio_track_subscribed=self._on_audio_track_subscribed,
@@ -950,7 +947,6 @@ class LiveKitTransport(BaseTransport):
self._register_event_handler("on_first_participant_joined")
self._register_event_handler("on_participant_left")
self._register_event_handler("on_call_state_updated")
self._register_event_handler("on_before_disconnect", sync=True)
def input(self) -> LiveKitInputTransport:
"""Get the input transport for receiving media and events.
@@ -1045,10 +1041,6 @@ class LiveKitTransport(BaseTransport):
"""Handle room disconnected events."""
await self._call_event_handler("on_disconnected")
async def _on_before_disconnect(self):
"""Handle before disconnection room events."""
await self._call_event_handler("on_before_disconnect")
async def _on_participant_connected(self, participant_id: str):
"""Handle participant connected events."""
await self._call_event_handler("on_participant_connected", participant_id)
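
The `on_before_disconnect` handler in these hunks is registered with `sync=True`, so it is awaited inline and finishes before the disconnect proceeds, whereas ordinary handlers are scheduled as tasks. The sketch below illustrates the difference using only asyncio; `EventSlot` and `dispatch` are hypothetical names, and only coroutine handlers are covered here.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, List

Handler = Callable[..., Awaitable[None]]


@dataclass
class EventSlot:
    """Hypothetical registry entry: handlers plus the inline/task flag."""

    name: str
    handlers: List[Handler] = field(default_factory=list)
    is_sync: bool = False


async def dispatch(slot: EventSlot, *args: Any) -> List[asyncio.Task]:
    """Run handlers inline when is_sync is set, otherwise schedule them."""
    tasks: List[asyncio.Task] = []
    for handler in slot.handlers:
        if slot.is_sync:
            await handler(*args)  # finished before dispatch() returns
        else:
            tasks.append(asyncio.create_task(handler(*args)))  # runs later
    return tasks


async def demo() -> List[str]:
    log: List[str] = []

    async def handler() -> None:
        await asyncio.sleep(0)
        log.append("handler")

    await dispatch(EventSlot("on_before_disconnect", [handler], is_sync=True))
    log.append("disconnect")

    tasks = await dispatch(EventSlot("on_left", [handler]))
    log.append("after-dispatch")
    await asyncio.gather(*tasks)
    return log
```

In `demo()`, the inline handler logs before "disconnect", while the scheduled one logs only after the caller has already moved on.
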

View File

@@ -95,20 +95,15 @@ class SmallWebRTCTrack:
enable/disable control and frame discarding for audio and video streams.
"""
def __init__(self, receiver):
def __init__(self, track: MediaStreamTrack):
"""Initialize the WebRTC track wrapper.
Args:
receiver: The RemoteStreamTrack receiver instance.
track: The underlying MediaStreamTrack to wrap.
index: The index of the track in the transceiver (0 for mic, 1 for cam, 2 for screen)
"""
self._receiver = receiver
# Configuring the receiver for not consuming the track by default to prevent memory grow
self._receiver._enabled = False
self._track = receiver.track
self._track = track
self._enabled = True
self._last_recv_time: float = 0.0
self._idle_task: Optional[asyncio.Task] = None
self._idle_timeout: float = 2.0 # seconds before discarding old frames
def set_enabled(self, enabled: bool) -> None:
"""Enable or disable the track.
@@ -143,44 +138,13 @@ class SmallWebRTCTrack:
async def recv(self) -> Optional[Frame]:
"""Receive the next frame from the track.
Enables the internal receiving state and starts idle watcher.
Returns:
The next frame. For video tracks, the frame is returned only if the track is enabled; otherwise None is returned.
"""
self._receiver._enabled = True
self._last_recv_time = time.time()
# start idle watcher if not already running
if not self._idle_task or self._idle_task.done():
self._idle_task = asyncio.create_task(self._idle_watcher())
if not self._enabled and self._track.kind == "video":
return None
return await self._track.recv()
async def _idle_watcher(self):
"""Disable receiving if idle for more than _idle_timeout and monitor queue size."""
while self._receiver._enabled:
await asyncio.sleep(self._idle_timeout)
idle_duration = time.time() - self._last_recv_time
if idle_duration >= self._idle_timeout:
# discard old frames to prevent memory growth
logger.debug(
f"Disabling receiver for {self._track.kind} track after {idle_duration:.2f}s idle"
)
await self.discard_old_frames()
self._receiver._enabled = False
def stop(self):
"""Stop receiving frames from the track."""
self._receiver._enabled = False
if self._idle_task:
self._idle_task.cancel()
self._idle_task = None
if self._track:
self._track.stop()
def __getattr__(self, name):
"""Forward attribute access to the underlying track.
@@ -490,10 +454,6 @@ class SmallWebRTCConnection(BaseObject):
async def _close(self):
"""Close the peer connection and cleanup resources."""
for track in self._track_map.values():
if track:
track.stop()
self._track_map.clear()
if self._pc:
await self._pc.close()
self._message_queue.clear()
@@ -566,8 +526,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No audio transceiver is available")
return None
receiver = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver
audio_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
audio_track = SmallWebRTCTrack(track) if track else None
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
return audio_track
@@ -588,8 +548,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No video transceiver is available")
return None
receiver = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
@@ -610,8 +570,8 @@ class SmallWebRTCConnection(BaseObject):
logger.warning("No screen video transceiver is available")
return None
receiver = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver
video_track = SmallWebRTCTrack(receiver) if receiver else None
track = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
self._track_map[SCREEN_VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
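
The receiver-based variant of `SmallWebRTCTrack` in this file keeps the receiver disabled until the first `recv()` and uses an idle watcher to disable it again when nobody reads, so unread frames do not pile up in memory. A minimal sketch of that pattern follows; the `IdleGate` class is hypothetical, uses `time.monotonic()` rather than `time.time()`, and does not touch aiortc.

```python
import asyncio
import time
from typing import Optional, Tuple


class IdleGate:
    """Hypothetical receive gate that closes itself after a period without reads."""

    def __init__(self, idle_timeout: float = 2.0) -> None:
        self.enabled = False  # closed until the first read
        self._idle_timeout = idle_timeout
        self._last_read = 0.0
        self._watcher: Optional[asyncio.Task] = None

    def touch(self) -> None:
        """Record a read, open the gate, and make sure a watcher is running."""
        self.enabled = True
        self._last_read = time.monotonic()
        if self._watcher is None or self._watcher.done():
            self._watcher = asyncio.create_task(self._watch())

    async def _watch(self) -> None:
        while self.enabled:
            await asyncio.sleep(self._idle_timeout)
            if time.monotonic() - self._last_read >= self._idle_timeout:
                self.enabled = False  # no reader: stop consuming frames

    def stop(self) -> None:
        """Close the gate and cancel the watcher, if any."""
        self.enabled = False
        if self._watcher:
            self._watcher.cancel()
            self._watcher = None


async def demo() -> Tuple[bool, bool]:
    gate = IdleGate(idle_timeout=0.05)
    gate.touch()
    open_after_read = gate.enabled
    await asyncio.sleep(0.3)  # no further reads: the watcher closes the gate
    closed_after_idle = gate.enabled
    gate.stop()
    return open_after_read, closed_after_idle
```

`touch()` has to be called from inside a running event loop because it creates a task, which matches calling it from an async `recv()`.
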

View File

@@ -1,200 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""SmallWebRTC request handler for managing peer connections.
This module provides a client for handling web requests and managing WebRTC connections.
"""
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional
from fastapi import HTTPException
from loguru import logger
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
@dataclass
class SmallWebRTCRequest:
"""Small WebRTC transport session arguments for the runner.
Parameters:
sdp: The SDP string (Session Description Protocol).
type: The type of the SDP, either "offer" or "answer".
pc_id: Optional identifier for the peer connection.
restart_pc: Whether to restart the peer connection (optional).
request_data: Optional custom data sent by the customer.
"""
sdp: str
type: str
pc_id: Optional[str] = None
restart_pc: Optional[bool] = None
request_data: Optional[Any] = None
class ConnectionMode(Enum):
"""Enum defining the connection handling modes."""
SINGLE = "single" # Only one active connection allowed
MULTIPLE = "multiple" # Multiple simultaneous connections allowed
class SmallWebRTCRequestHandler:
"""SmallWebRTC request handler for managing peer connections.
This class is responsible for:
- Handling incoming SmallWebRTC requests.
- Creating and managing WebRTC peer connections.
- Supporting ESP32-specific SDP munging if enabled.
- Invoking callbacks for newly initialized connections.
- Supporting both single and multiple connection modes.
"""
def __init__(
self,
ice_servers: Optional[List[IceServer]] = None,
esp32_mode: bool = False,
host: Optional[str] = None,
connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
) -> None:
"""Initialize a SmallWebRTC request handler.
Args:
ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
connections.
esp32_mode (bool): If True, enables ESP32-specific SDP munging.
host (Optional[str]): Host address used for SDP munging in ESP32 mode.
Ignored if `esp32_mode` is False.
connection_mode (ConnectionMode): Mode of operation for handling connections.
SINGLE allows only one active connection, MULTIPLE allows several.
"""
self._ice_servers = ice_servers
self._esp32_mode = esp32_mode
self._host = host
self._connection_mode = connection_mode
# Store connections by pc_id
self._pcs_map: Dict[str, SmallWebRTCConnection] = {}
def _check_single_connection_constraints(self, pc_id: Optional[str]) -> None:
"""Check if the connection request satisfies single connection mode constraints.
Args:
pc_id: The peer connection ID from the request
Raises:
HTTPException: If constraints are violated in single connection mode
"""
if self._connection_mode != ConnectionMode.SINGLE:
return
if not self._pcs_map: # No existing connections
return
# Get the existing connection (should be only one in single mode)
existing_connection = next(iter(self._pcs_map.values()))
if existing_connection.pc_id != pc_id and pc_id:
logger.warning(
f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
)
raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")
if not pc_id:
logger.warning(
"Cannot create new connection: existing connection found but no pc_id received"
)
raise HTTPException(
status_code=400,
detail="Cannot create new connection with existing connection active",
)
async def handle_web_request(
self,
request: SmallWebRTCRequest,
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
) -> Dict[str, Any]:
"""Handle a SmallWebRTC request and resolve the pending answer.
This method will:
- Reuse an existing WebRTC connection if `pc_id` exists.
- Otherwise, create a new `SmallWebRTCConnection`.
- Invoke the provided callback with the connection.
- Manage ESP32-specific munging if enabled.
- Enforce single/multiple connection mode constraints.
Args:
request (SmallWebRTCRequest): The incoming WebRTC request, containing
SDP, type, and optionally a `pc_id`.
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
asynchronous callback function that is invoked with the WebRTC connection.
Raises:
HTTPException: If connection mode constraints are violated
Exception: Any exception raised during request handling or callback execution
will be logged and propagated.
"""
try:
pc_id = request.pc_id
# Check connection mode constraints first
self._check_single_connection_constraints(pc_id)
# After constraints are satisfied, get the existing connection if any
existing_connection = self._pcs_map.get(pc_id) if pc_id else None
if existing_connection:
pipecat_connection = existing_connection
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request.sdp,
type=request.type,
restart_pc=request.restart_pc or False,
)
else:
pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
await pipecat_connection.initialize(sdp=request.sdp, type=request.type)
@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
self._pcs_map.pop(webrtc_connection.pc_id, None)
# Invoke callback provided in runner arguments
try:
await webrtc_connection_callback(pipecat_connection)
logger.debug(
f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
)
except Exception as callback_error:
logger.error(
f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
)
answer = pipecat_connection.get_answer()
if self._esp32_mode and self._host and self._host != "localhost":
from pipecat.runner.utils import smallwebrtc_sdp_munging
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)
self._pcs_map[answer["pc_id"]] = pipecat_connection
return answer
except Exception as e:
logger.error(f"Error processing SmallWebRTC request: {e}")
logger.debug(f"SmallWebRTC request details: {request}")
raise
async def close(self):
"""Disconnect all tracked peer connections and clear the connection map."""
coros = [pc.disconnect() for pc in self._pcs_map.values()]
await asyncio.gather(*coros)
self._pcs_map.clear()
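
The single-connection check in `_check_single_connection_constraints` reduces to two rules: with no active connection any request is accepted, and with one active a request must carry the matching `pc_id`. A stdlib-only sketch follows; `ConnectionRejected` is a hypothetical stand-in for the HTTP 400 raised by the handler, and the sketch assumes the map is keyed by `pc_id`, as the assignment to `_pcs_map` suggests.

```python
from typing import Dict, Optional


class ConnectionRejected(Exception):
    """Hypothetical stand-in for the HTTP 400 raised by the handler."""


def check_single_connection(existing: Dict[str, object], pc_id: Optional[str]) -> None:
    """Raise if a request would violate single-connection mode."""
    if not existing:
        return  # nothing active yet: any request may create a connection
    existing_id = next(iter(existing))  # at most one entry in single mode
    if not pc_id:
        raise ConnectionRejected("a connection is already active")
    if pc_id != existing_id:
        raise ConnectionRejected("pc_id does not match the active connection")


def is_allowed(existing: Dict[str, object], pc_id: Optional[str]) -> bool:
    """Boolean wrapper, convenient for quick checks."""
    try:
        check_single_connection(existing, pc_id)
        return True
    except ConnectionRejected:
        return False
```

A request that repeats the active `pc_id` passes the check, which is what lets the handler renegotiate the existing connection.
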

View File

@@ -14,33 +14,13 @@ and async cleanup for all Pipecat components.
import asyncio
import inspect
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Optional
from loguru import logger
from pipecat.utils.utils import obj_count, obj_id
@dataclass
class EventHandler:
"""Data class to store event handlers information.
This data class stores the event name, a list of handlers to run for this
event, and whether these handlers are awaited inline rather than run in a task.
Attributes:
name (str): The name of the event handler.
handlers (List[Any]): A list of functions to be called when this event is triggered.
is_sync (bool): If True, handlers are awaited inline; otherwise each handler runs in its own task.
"""
name: str
handlers: List[Any]
is_sync: bool
class BaseObject(ABC):
"""Abstract base class providing common functionality for Pipecat objects.
@@ -61,7 +41,7 @@ class BaseObject(ABC):
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
# Registered event handlers.
self._event_handlers: Dict[str, EventHandler] = {}
self._event_handlers: dict = {}
# Set of tasks being executed. When a task finishes running it gets
# automatically removed from the set. When we cleanup we wait for all
@@ -123,21 +103,18 @@ class BaseObject(ABC):
Can be sync or async.
"""
if event_name in self._event_handlers:
self._event_handlers[event_name].handlers.append(handler)
self._event_handlers[event_name].append(handler)
else:
logger.warning(f"Event handler {event_name} not registered")
def _register_event_handler(self, event_name: str, sync: bool = False):
def _register_event_handler(self, event_name: str):
"""Register an event handler type.
Args:
event_name: The name of the event type to register.
sync: Whether handlers for this event are awaited directly instead of being run in a separate task.
"""
if event_name not in self._event_handlers:
self._event_handlers[event_name] = EventHandler(
name=event_name, handlers=[], is_sync=sync
)
self._event_handlers[event_name] = []
else:
logger.warning(f"Event handler {event_name} already registered")
@@ -149,43 +126,34 @@ class BaseObject(ABC):
*args: Positional arguments to pass to event handlers.
**kwargs: Keyword arguments to pass to event handlers.
"""
if event_name not in self._event_handlers:
# If we haven't registered an event handler, we don't need to do
# anything.
if not self._event_handlers.get(event_name):
return
event_handler = self._event_handlers[event_name]
# Create the task.
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
for handler in event_handler.handlers:
if event_handler.is_sync:
# Just run the handler.
await self._run_handler(event_handler.name, handler, *args, **kwargs)
else:
# Create the task. Note that this is a task per each function
# handler. Users can register to an event handler multiple
# times.
task = asyncio.create_task(
self._run_handler(event_handler.name, handler, *args, **kwargs)
)
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
# Add it to our list of event tasks.
self._event_tasks.add((event_name, task))
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
# Remove the task from the event tasks list when the task completes.
task.add_done_callback(self._event_task_finished)
async def _run_handler(self, event_name: str, handler, *args, **kwargs):
async def _run_task(self, event_name: str, *args, **kwargs):
"""Execute all handlers for an event.
Args:
event_name: The event name for this handler.
handler: The handler function to run.
event_name: The name of the event being handled.
*args: Positional arguments to pass to handlers.
**kwargs: Keyword arguments to pass to handlers.
"""
try:
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
for handler in self._event_handlers[event_name]:
if inspect.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
handler(self, *args, **kwargs)
except Exception as e:
logger.exception(f"Exception in event handler {event_name}: {e}")
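To make the dispatch behaviour in the `_run_task` variant easier to follow, here is a minimal, self-contained sketch. `MiniEmitter`, `register`, `call` and `wait_for_events` are hypothetical names invented for illustration and are not Pipecat APIs. The sketch assumes one task per event call, with all handlers for that event executed sequentially inside the task, mixing sync and async handlers.

```python
import asyncio
import inspect


class MiniEmitter:
    """Hypothetical, stripped-down stand-in for BaseObject's event plumbing."""

    def __init__(self):
        self._event_handlers: dict = {}
        self._event_tasks: set = set()

    def register(self, event_name: str):
        self._event_handlers.setdefault(event_name, [])

    def add_event_handler(self, event_name: str, handler):
        if event_name in self._event_handlers:
            self._event_handlers[event_name].append(handler)

    async def call(self, event_name: str, *args, **kwargs):
        # Unknown event or no handlers attached: nothing to do.
        if not self._event_handlers.get(event_name):
            return
        # One task per event call; handlers run sequentially inside it.
        task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
        self._event_tasks.add(task)
        task.add_done_callback(self._event_tasks.discard)

    async def _run_task(self, event_name: str, *args, **kwargs):
        for handler in self._event_handlers[event_name]:
            if inspect.iscoroutinefunction(handler):
                await handler(self, *args, **kwargs)
            else:
                handler(self, *args, **kwargs)

    async def wait_for_events(self):
        if self._event_tasks:
            await asyncio.gather(*list(self._event_tasks))


async def demo() -> list:
    calls = []
    emitter = MiniEmitter()
    emitter.register("on_connected")

    async def first(obj, who):
        await asyncio.sleep(0.01)
        calls.append(f"first:{who}")

    def second(obj, who):
        calls.append(f"second:{who}")

    emitter.add_event_handler("on_connected", first)
    emitter.add_event_handler("on_connected", second)
    await emitter.call("on_connected", "alice")
    await emitter.call("on_unknown", "bob")  # silently ignored
    await emitter.wait_for_events()
    return calls
```

One consequence of running all handlers in a single task is that a slow handler delays the ones registered after it. In the diff, the `try` block wraps the whole loop, so an exception in one handler also skips the remaining handlers for that call.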

172
test_openai_agent.py Normal file

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""Simple test script for OpenAI Agent service."""
import asyncio
import os
from unittest.mock import MagicMock, patch
# Mock the OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key-for-testing"
from pipecat.frames.frames import TextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai_agent import OpenAIAgentService
async def test_basic_functionality():
"""Test basic OpenAI Agent service functionality."""
print("🧪 Testing OpenAI Agent Service...")
# Create a simple weather tool for testing
def get_weather(location: str) -> str:
"""Get weather for a location."""
return f"The weather in {location} is sunny and 22°C."
try:
# Create the service
print("📋 Creating OpenAI Agent service...")
service = OpenAIAgentService(
name="Test Assistant",
instructions="You are a helpful test assistant.",
tools=[get_weather],
api_key="test-key",
streaming=True,
)
print(f"✅ Service created successfully!")
print(f" - Agent name: {service.agent.name}")
print(f" - Model name: {service.model_name}")
print(f" - Streaming enabled: {service._streaming}")
# Test basic configuration
print("⚙️ Testing configuration updates...")
service.update_agent_config(
instructions="Updated test instructions",
model_config={"model": "gpt-4o", "temperature": 0.5},
)
print(f"✅ Configuration updated!")
print(f" - New instructions: {service.agent.instructions}")
print(f" - New model: {service.model_name}")
# Test session context
print("💾 Testing session context...")
service.update_session_context({"user_id": "test-user", "session": "test-session"})
context = service.get_session_context()
print(f"✅ Session context managed!")
print(f" - Context keys: {list(context.keys())}")
# Test adding tools
print("🔧 Testing tool management...")
def get_time() -> str:
"""Get current time."""
return "The current time is 3:00 PM."
await service.add_tool(get_time)
print(f"✅ Tool added successfully!")
print("\n🎉 All basic functionality tests passed!")
return True
except Exception as e:
print(f"❌ Test failed with error: {e}")
return False
async def test_frame_processing():
"""Test frame processing with mocked responses."""
print("\n🔄 Testing frame processing...")
try:
# Mock the Runner to avoid actual API calls
with patch("pipecat.services.openai_agent.agent_service.Runner") as mock_runner:
# Set up mock responses
mock_stream_result = MagicMock()
# Mock stream events
async def mock_stream_events():
# Simulate streaming response
yield MagicMock(type="raw_response_event", data=MagicMock(delta="Hello "))
yield MagicMock(type="raw_response_event", data=MagicMock(delta="from "))
yield MagicMock(type="raw_response_event", data=MagicMock(delta="agent!"))
# Simulate completed message
mock_item = MagicMock()
mock_item.type = "message_output_item"
mock_item.content = "Hello from agent!"
yield MagicMock(type="run_item_stream_event", item=mock_item)
mock_stream_result.stream_events.return_value = mock_stream_events()
mock_runner.run_streamed.return_value = mock_stream_result
# Create service with mocked runner
service = OpenAIAgentService(
name="Test Assistant",
instructions="You are a helpful test assistant.",
api_key="test-key",
streaming=True,
)
# Collect output frames
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
print(f" 📤 Frame: {type(frame).__name__}")
if hasattr(frame, "text"):
print(f" Text: '{frame.text}'")
service.push_frame = mock_push_frame
# Process a text frame
print("📝 Processing text frame...")
text_frame = TextFrame("Hello, how are you?")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait for async processing
await asyncio.sleep(0.2)
print(f"✅ Frame processing completed!")
print(f" - Generated {len(output_frames)} output frames")
# Check if we got expected frame types
frame_types = [type(frame).__name__ for frame in output_frames]
print(f" - Frame types: {frame_types}")
return True
except Exception as e:
print(f"❌ Frame processing test failed: {e}")
return False
async def main():
"""Run all tests."""
print("🚀 Starting OpenAI Agent Service Tests\n")
try:
# Run basic functionality tests
basic_test = await test_basic_functionality()
# Run frame processing tests
frame_test = await test_frame_processing()
# Summary
print(f"\n📊 Test Results:")
print(f" - Basic functionality: {'✅ PASS' if basic_test else '❌ FAIL'}")
print(f" - Frame processing: {'✅ PASS' if frame_test else '❌ FAIL'}")
if basic_test and frame_test:
print(f"\n🎉 All tests passed! The OpenAI Agent service is working correctly.")
else:
print(f"\n⚠️ Some tests failed. Please check the output above.")
except Exception as e:
print(f"❌ Test suite failed with error: {e}")
if __name__ == "__main__":
asyncio.run(main())
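The frame-processing test above mocks a streaming result by assigning an async generator object to `stream_events.return_value`. A generator object can only be consumed once, so a second call to `stream_events()` would yield nothing. The sketch below shows one possible alternative using `side_effect`, so each call gets a fresh generator. The helper names (`fake_stream_events`, `make_stream_result`, `collect_text`) are hypothetical and the event shape simply mirrors the mocks in the test, not a confirmed SDK contract.

```python
import asyncio
from unittest.mock import MagicMock


async def fake_stream_events():
    """Yield hypothetical streaming events shaped like the ones mocked above."""
    for delta in ("Hello ", "from ", "agent!"):
        yield MagicMock(type="raw_response_event", data=MagicMock(delta=delta))


def make_stream_result() -> MagicMock:
    result = MagicMock()
    # side_effect calls the function on every stream_events() call, so each
    # caller gets a fresh async generator instead of a shared, exhausted one.
    result.stream_events.side_effect = fake_stream_events
    return result


async def collect_text(stream_result) -> str:
    chunks = []
    async for event in stream_result.stream_events():
        if event.type == "raw_response_event":
            chunks.append(event.data.delta)
    return "".join(chunks)


async def demo():
    result = make_stream_result()
    first = await collect_text(result)
    second = await collect_text(result)
    return first, second
```

With this setup the same mock can back several `process_frame` calls in one test without silently producing an empty stream on the second call.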

33
test_simple_agent.py Normal file

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
import asyncio
import os
from loguru import logger
# Test the actual agents package API
try:
from agents import Agent, Runner
# Create a simple agent
agent = Agent(
name="test-agent",
instructions="You are a helpful assistant.",
)
print("✅ Agent created successfully!")
print(f"Agent name: {agent.name}")
# Test a simple conversation
async def test_agent():
# `agents.run` is a submodule, not a callable; Runner.run executes the agent loop
result = await Runner.run(agent, "Hello, how are you?")
print(f"Agent response: {result.final_output}")
# Run the test
asyncio.run(test_agent())
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()


@@ -1,67 +0,0 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import unittest
from pipecat.frames.frames import (
EndFrame,
Frame,
InterruptionFrame,
TextFrame,
TransportMessageUrgentFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
async def test_interruption_and_wait(self):
class DelayFrameProcessor(FrameProcessor):
"""This processor just gives the event loop time to switch
between tasks. Otherwise things happen too fast."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await asyncio.sleep(0.1)
await self.push_frame(frame, direction)
class InterruptFrameProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_interruption_task_frame_and_wait()
await self.push_frame(TransportMessageUrgentFrame(message=frame.text))
else:
await self.push_frame(frame, direction)
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
frames_to_send = [
# Just a random interruption to make sure we don't clear anything
# before the actual `InterruptionTaskFrame` interruption.
InterruptionFrame(),
# This will generate an `InterruptionTaskFrame` and will wait for an
# `InterruptionFrame`.
TextFrame(text="Hello from Pipecat!"),
# Just give time for everything to complete.
SleepFrame(sleep=0.5),
EndFrame(),
]
expected_down_frames = [
InterruptionFrame,
InterruptionFrame,
TransportMessageUrgentFrame,
EndFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
send_end_frame=False,
)


@@ -0,0 +1,286 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for OpenAI Agent service."""
import asyncio
import os
import sys
import unittest.mock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add src to path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
class MockAgent:
"""Mock Agent for testing."""
def __init__(self, name="Test Agent", instructions="Test instructions"):
self.name = name
self.instructions = instructions
self.tools = []
self.handoffs = []
class MockRunResult:
"""Mock RunResult for testing."""
def __init__(self, final_output="Test response"):
self.final_output = final_output
class MockStreamEvent:
"""Mock StreamEvent for testing."""
def __init__(self, event_type, data=None, item=None):
self.type = event_type
self.data = data
self.item = item
class MockMessageItem:
"""Mock message item for testing."""
def __init__(self, content="Test content"):
self.type = "message_output_item"
self.content = content
class MockRunner:
"""Mock Runner for testing."""
@staticmethod
async def run(agent, input_text, context=None):
return MockRunResult("Mocked response")
@staticmethod
def run_streamed(agent, input_text, context=None):
class MockStreamResult:
async def stream_events(self):
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="Test "))
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="response"))
yield MockStreamEvent(
"run_item_stream_event", item=MockMessageItem("Test response")
)
return MockStreamResult()
@pytest.fixture
def mock_openai_agents():
"""Mock the OpenAI Agents SDK imports."""
with patch.dict(
"sys.modules",
{
"agents": MagicMock(),
"agents.stream_events": MagicMock(),
"agents.result": MagicMock(),
},
):
# Mock the classes and functions we need
mock_agent = MagicMock()
mock_agent.return_value = MockAgent()
mock_runner = MagicMock()
mock_runner.run = AsyncMock(return_value=MockRunResult())
mock_runner.run_streamed = MagicMock(return_value=MockRunner.run_streamed(None, None))
with (
patch("pipecat.services.openai_agent.agent_service.Agent", mock_agent),
patch("pipecat.services.openai_agent.agent_service.Runner", mock_runner),
):
yield {
"Agent": mock_agent,
"Runner": mock_runner,
}
@pytest.mark.asyncio
async def test_openai_agent_service_init(mock_openai_agents):
"""Test OpenAI Agent service initialization."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
)
assert service.agent.name == "Test Agent"
assert service._streaming is True
@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_streaming(mock_openai_agents):
"""Test processing text frame with streaming enabled."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
)
# Mock the push_frame method to capture output
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
service.push_frame = mock_push_frame
# Process a text frame
text_frame = TextFrame("Hello, agent!")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait a bit for async processing
await asyncio.sleep(0.1)
# Check that appropriate frames were generated
assert len(output_frames) > 0
assert any(isinstance(frame, LLMFullResponseStartFrame) for frame in output_frames)
@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_non_streaming(mock_openai_agents):
"""Test processing text frame with streaming disabled."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=False
)
# Mock the push_frame method to capture output
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
service.push_frame = mock_push_frame
# Process a text frame
text_frame = TextFrame("Hello, agent!")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait a bit for async processing
await asyncio.sleep(0.1)
# Check that appropriate frames were generated
assert len(output_frames) > 0
@pytest.mark.asyncio
async def test_openai_agent_service_update_config(mock_openai_agents):
"""Test updating agent configuration."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Update configuration
service.update_agent_config(
instructions="Updated instructions", model_config={"model": "gpt-4o", "temperature": 0.7}
)
assert service.agent.instructions == "Updated instructions"
assert service.agent.model_config["model"] == "gpt-4o"
@pytest.mark.asyncio
async def test_openai_agent_service_session_context(mock_openai_agents):
"""Test session context management."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent",
instructions="Test instructions",
api_key="test-key",
session_config={"user_id": "test-user"},
)
# Get initial context
context = service.get_session_context()
assert context["user_id"] == "test-user"
# Update context
service.update_session_context({"session_id": "test-session"})
updated_context = service.get_session_context()
assert updated_context["user_id"] == "test-user"
assert updated_context["session_id"] == "test-session"
@pytest.mark.asyncio
async def test_openai_agent_service_add_tools(mock_openai_agents):
"""Test adding tools to the agent."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Define a test tool
def test_tool():
return "test result"
# Add the tool
await service.add_tool(test_tool)
# Check if tool was added (this depends on the mock implementation)
assert hasattr(service.agent, "tools")
@pytest.mark.asyncio
async def test_openai_agent_service_lifecycle(mock_openai_agents):
"""Test service lifecycle methods."""
from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Test start
start_frame = StartFrame()
await service.start(start_frame)
# Test cancel
cancel_frame = CancelFrame()
await service.cancel(cancel_frame)
# Test stop
end_frame = EndFrame()
await service.stop(end_frame)
def test_openai_agent_service_import_error():
"""Test that import error is handled gracefully."""
# Mock the import to fail
with patch.dict("sys.modules", {"agents": None}):
with pytest.raises(Exception) as exc_info:
# This should trigger the import error
import importlib
import pipecat.services.openai_agent.agent_service
importlib.reload(pipecat.services.openai_agent.agent_service)
assert "Missing module" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__])
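The import-error test above relies on `patch.dict("sys.modules", {"agents": None})`. As a small illustration of why that works, the sketch below blocks a standard-library module the same way. `import_blocked` is a hypothetical helper written for this note, and `colorsys` is used only because it is a harmless stdlib module.

```python
import importlib
import sys
from unittest.mock import patch


def import_blocked(module_name: str) -> bool:
    """Return True if importing module_name fails while it is blocked."""
    # A None entry in sys.modules makes the import system raise ImportError
    # for that name; patch.dict restores the original entries on exit.
    with patch.dict(sys.modules, {module_name: None}):
        try:
            importlib.import_module(module_name)
        except ImportError:
            return True
    return False
```

Once the `with` block exits, the module can be imported normally again, so the technique does not leak into other tests.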


@@ -1,303 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Unit tests for ServiceSwitcher and related components."""
import unittest
from pipecat.frames.frames import (
Frame,
ManuallySwitchServiceFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import run_test
class MockFrameProcessor(FrameProcessor):
"""A test frame processor that tracks which frames it has processed."""
def __init__(self, test_name: str, **kwargs):
"""Initialize the test processor with a name.
Args:
test_name: A unique name for this processor instance.
**kwargs: Additional arguments passed to the parent FrameProcessor.
"""
super().__init__(name=test_name, **kwargs)
self.test_name = test_name
self.processed_frames = []
self.frame_count = 0
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process an incoming frame and track it.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
self.processed_frames.append(frame)
self.frame_count += 1
await self.push_frame(frame, direction)
def reset_counters(self):
"""Reset the frame tracking counters."""
self.processed_frames = []
self.frame_count = 0
class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcherStrategyManual."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_services(self):
"""Test initialization with a list of services."""
strategy = ServiceSwitcherStrategyManual(self.services)
self.assertEqual(strategy.services, self.services)
self.assertEqual(strategy.active_service, self.service1) # First service should be active
def test_init_with_empty_services(self):
"""Test initialization with an empty list of services."""
strategy = ServiceSwitcherStrategyManual([])
self.assertEqual(strategy.services, [])
self.assertIsNone(strategy.active_service)
def test_handle_manually_switch_service_frame(self):
"""Test manual service switching with ManuallySwitchServiceFrame."""
strategy = ServiceSwitcherStrategyManual(self.services)
# Initially service1 should be active
self.assertEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
# Switch to service2
switch_frame = ManuallySwitchServiceFrame(service=self.service2)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertEqual(strategy.active_service, self.service2)
self.assertNotEqual(strategy.active_service, self.service3)
# Switch to service3
switch_frame = ManuallySwitchServiceFrame(service=self.service3)
strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)
self.assertNotEqual(strategy.active_service, self.service1)
self.assertNotEqual(strategy.active_service, self.service2)
self.assertEqual(strategy.active_service, self.service3)
def test_handle_frame_unsupported_frame_type(self):
"""Test that unsupported frame types raise an error."""
strategy = ServiceSwitcherStrategyManual(self.services)
unsupported_frame = TextFrame(text="test") # Not a ServiceSwitcherFrame
with self.assertRaises(ValueError) as context:
strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)
self.assertIn("Unsupported frame type", str(context.exception))
class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
"""Test cases for ServiceSwitcher."""
def setUp(self):
"""Set up test fixtures."""
self.service1 = MockFrameProcessor("service1")
self.service2 = MockFrameProcessor("service2")
self.service3 = MockFrameProcessor("service3")
self.services = [self.service1, self.service2, self.service3]
def test_init_with_manual_strategy(self):
"""Test initialization with manual strategy."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.services, self.services)
self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
self.assertEqual(switcher.strategy.services, self.services)
async def test_default_active_service(self):
"""Test that the initially-active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send some test frames
frames_to_send = [
TextFrame(text="Hello 1"),
TextFrame(text="Hello 2"),
TextFrame(text="Hello 3"),
]
await run_test(
switcher,
frames_to_send=frames_to_send,
expected_down_frames=[TextFrame, TextFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Only service1 should have processed the text frames
# Note: The service also receives StartFrame and EndFrame, so count those too
text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
self.assertEqual(len(text_frames), 3)
# Check that other services don't receive text frames (they might get StartFrame/EndFrame)
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service2_text_frames), 0)
self.assertEqual(len(service3_text_frames), 0)
# Verify the actual text frames processed
for i, frame in enumerate(text_frames):
self.assertEqual(frame.text, f"Hello {i + 1}")
async def test_service_switching(self):
"""Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
# Reset counters
for service in self.services:
service.reset_counters()
# Send a test frame, a switch frame, and another test frame
await run_test(
switcher,
frames_to_send=[
TextFrame("Hello 1"),
ManuallySwitchServiceFrame(service=self.service2),
TextFrame("Hello 2"),
],
expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
expected_up_frames=[], # Expect no error frames
)
# Verify service2 received the frame
service1_text_frames = [
f for f in self.service1.processed_frames if isinstance(f, TextFrame)
]
service2_text_frames = [
f for f in self.service2.processed_frames if isinstance(f, TextFrame)
]
service3_text_frames = [
f for f in self.service3.processed_frames if isinstance(f, TextFrame)
]
self.assertEqual(len(service1_text_frames), 1)
self.assertEqual(len(service2_text_frames), 1)
self.assertEqual(len(service3_text_frames), 0)
self.assertEqual(service1_text_frames[0].text, "Hello 1")
self.assertEqual(service2_text_frames[0].text, "Hello 2")
async def test_multi_service_switcher_targeting(self):
"""Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline."""
# Create services for first switcher
switcher1_service1 = MockFrameProcessor("switcher1_service1")
switcher1_service2 = MockFrameProcessor("switcher1_service2")
switcher1_services = [switcher1_service1, switcher1_service2]
# Create services for second switcher
switcher2_service1 = MockFrameProcessor("switcher2_service1")
switcher2_service2 = MockFrameProcessor("switcher2_service2")
switcher2_services = [switcher2_service1, switcher2_service2]
# Create two service switchers
switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)
# Create a pipeline with both switchers: switcher1 -> switcher2
pipeline = Pipeline([switcher1, switcher2])
# Reset counters
for service in switcher1_services + switcher2_services:
service.reset_counters()
# Initially, both switchers should use their first services
self.assertEqual(switcher1.strategy.active_service, switcher1_service1)
self.assertEqual(switcher2.strategy.active_service, switcher2_service1)
# Send frames to test the pipeline:
# 1. Text frame (should go through both switchers' active services)
# 2. Switch frame targeting switcher1's second service
# 3. Text frame (should go through switcher1's new service and switcher2's original service)
# 4. Switch frame targeting switcher2's second service
# 5. Text frame (should go through switcher1's current service and switcher2's new service)
await run_test(
pipeline,
frames_to_send=[
TextFrame("Before any switches"),
ManuallySwitchServiceFrame(service=switcher1_service2), # Switch first switcher
TextFrame("After switching first switcher"),
ManuallySwitchServiceFrame(service=switcher2_service2), # Switch second switcher
TextFrame("After switching second switcher"),
],
expected_down_frames=[
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
ManuallySwitchServiceFrame,
TextFrame,
],
expected_up_frames=[], # Expect no error frames
)
# Verify the active services changed correctly
self.assertEqual(switcher1.strategy.active_service, switcher1_service2)
self.assertEqual(switcher2.strategy.active_service, switcher2_service2)
# Verify frame distribution:
# First text frame should go through switcher1_service1 and switcher2_service1
switcher1_service1_texts = [
f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame)
]
switcher2_service1_texts = [
f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame)
]
# Second text frame should go through switcher1_service2 and switcher2_service1
switcher1_service2_texts = [
f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame)
]
# Third text frame should go through switcher1_service2 and switcher2_service2
switcher2_service2_texts = [
f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame)
]
# Verify frame counts and content
self.assertEqual(len(switcher1_service1_texts), 1)
self.assertEqual(switcher1_service1_texts[0].text, "Before any switches")
self.assertEqual(len(switcher1_service2_texts), 2)
self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher")
self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher")
self.assertEqual(len(switcher2_service1_texts), 2)
self.assertEqual(switcher2_service1_texts[0].text, "Before any switches")
self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher")
self.assertEqual(len(switcher2_service2_texts), 1)
self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")
if __name__ == "__main__":
unittest.main()

8191
uv.lock generated

File diff suppressed because it is too large