Compare commits
8 Commits
cb/frame-p
...
hush/openA
| Author | SHA1 | Date | |
|---|---|---|---|
| | 29d4a56663 | | |
| | 373a09ecd6 | | |
| | 07f54c48f3 | | |
| | c8a3d65aa4 | | |
| | 50a2a0dc86 | | |
| | 0421d97954 | | |
| | 54c8f336c3 | | |
| | b086fbafe6 | | |
285
AGENTS.md
Normal file
@@ -0,0 +1,285 @@
# AGENTS.md

## Project Overview

Pipecat is an open-source Python framework for building real-time voice and multimodal conversational AI agents. The codebase is organized around a pipeline architecture where data flows through connected services (STT → LLM → TTS).

## Development Environment Setup

### Prerequisites

- **Minimum Python Version:** 3.10
- **Recommended Python Version:** 3.12
- **Package Manager:** uv (recommended) or pip

### Setup Commands

```bash
# Clone the repository
git clone https://github.com/pipecat-ai/pipecat.git
cd pipecat

# Install dependencies with uv (recommended)
uv sync --group dev --all-extras \
  --no-extra gstreamer \
  --no-extra krisp \
  --no-extra local \
  --no-extra ultravox

# Or with pip
pip install -e ".[dev]"

# Install pre-commit hooks
uv run pre-commit install

# Set up environment variables
cp env.example .env
```

## Build and Test Commands

### Running Tests
```bash
# Run all tests
uv run pytest

# Run specific test file
uv run pytest tests/test_name.py

# Run tests with coverage
uv run pytest --cov=pipecat --cov-report=html
```

### Code Quality
```bash
# Format code (required before commits)
uv run ruff format

# Lint code
uv run ruff check

# Type checking
uv run mypy src/pipecat

# Run pre-commit checks manually
uv run pre-commit run --all-files
```

### Documentation
```bash
# Build API documentation
cd docs/api
./build-docs.sh

# Build docs manually
sphinx-build -b html . _build/html -W --keep-going
```

## Code Style Guidelines

### Python Standards
- **Formatting:** Strict PEP 8 via Ruff
- **Docstrings:** Google-style format
- **Type Hints:** Required for all public APIs
- **Import Organization:** Automated via Ruff

### Docstring Conventions
- **Classes:** Describe purpose + `__init__` with complete `Args:` section
- **Dataclasses:** Use `Parameters:` section, no `__init__` docstring
- **Methods:** Include `Args:` and `Returns:` sections
- **Properties:** Must have `Returns:` section
- **Examples:** Use `Examples:` section with `::` syntax
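
The conventions above could look like the following minimal sketch. `RetryConfig` and `TextNormalizer` are hypothetical names invented for illustration, and the sketch assumes that "`::` syntax" refers to the reStructuredText literal-block marker placed after the `Examples` title.

```python
from dataclasses import dataclass


@dataclass
class RetryConfig:
    """Retry settings for a hypothetical service client.

    Parameters:
        max_attempts: Maximum number of attempts before giving up.
        backoff_secs: Delay between attempts, in seconds.
    """

    max_attempts: int = 3
    backoff_secs: float = 0.5


class TextNormalizer:
    """Normalizes text before it is sent to a TTS service.

    Examples::

        normalizer = TextNormalizer(lowercase=True)
        normalizer.normalize("  Hello  ")
    """

    def __init__(self, *, lowercase: bool = False):
        """Initialize the normalizer.

        Args:
            lowercase: Whether to lowercase the text after trimming it.
        """
        self._lowercase = lowercase

    @property
    def lowercase(self) -> bool:
        """Whether normalized text is lowercased.

        Returns:
            True if lowercasing is enabled.
        """
        return self._lowercase

    def normalize(self, text: str) -> str:
        """Trim surrounding whitespace and optionally lowercase the text.

        Args:
            text: The raw input text.

        Returns:
            The normalized text.
        """
        text = text.strip()
        return text.lower() if self._lowercase else text
```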

### File Organization
```
src/pipecat/            # Main package
├── processors/         # Frame processors
├── services/           # AI service integrations
├── transports/         # Communication layers
├── frames/             # Data frame definitions
└── pipeline/           # Pipeline orchestration

examples/foundational/  # Step-by-step tutorials
tests/                  # Test suite
```

## Testing Instructions

### Test Structure
- **Unit Tests:** Test individual components in isolation
- **Integration Tests:** Test service interactions
- **Example Tests:** Validate foundational examples work

### Adding Tests
```bash
# Test naming convention: test_<component>_<functionality>.py

# Run specific test pattern
uv run pytest -k "test_pipeline"

# Run with debugging
uv run pytest -s -vv tests/test_name.py::test_function
```

### Pre-commit Requirements
All commits must pass:
- Ruff formatting
- Ruff linting
- Type checking
- Basic test suite

## Dependency Management

### Using uv (Recommended)
```bash
# Add runtime dependency
uv add package-name

# Add optional dependency
uv add --optional service package-name

# Add development dependency
uv add --group dev package-name

# Update lockfile
uv lock

# Sync dependencies
uv sync
```

### Important Notes
- **Always commit both `pyproject.toml` and `uv.lock` together**
- **Never manually edit `uv.lock`** - it's auto-generated
- **Use extras for optional service dependencies** (e.g., `[openai]`, `[cartesia]`)

## Project Structure Guidelines

### Service Integration
When adding new AI services:
1. Create service class in `src/pipecat/services/<provider>/`
2. Follow existing patterns (e.g., `STTService`, `LLMService`)
3. Add to appropriate extras in `pyproject.toml`
4. Include tests in `tests/`
5. Add documentation examples

### Frame Processing
For custom processors:
1. Inherit from `FrameProcessor`
2. Implement `process_frame()` method. ALWAYS explicitly call `await super().process_frame(frame, direction)` at the top of this method.
3. Handle frame direction (`FrameDirection.UPSTREAM` / `FrameDirection.DOWNSTREAM`)
4. Add proper type hints and docstrings
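
The steps above can be sketched as follows. To keep the sketch runnable without Pipecat installed, `FrameProcessor`, `FrameDirection`, and `TextFrame` here are simplified stand-ins written for this example, not the framework's real classes, and `UppercaseProcessor` is a hypothetical processor. The real base class may behave differently; only the shape of the override is the point.

```python
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple


class FrameDirection(Enum):
    """Stand-in for the framework's frame direction enum."""

    DOWNSTREAM = 1
    UPSTREAM = 2


@dataclass
class TextFrame:
    """Stand-in for a text-carrying frame."""

    text: str


class FrameProcessor:
    """Stand-in base class that records calls instead of linking processors."""

    def __init__(self) -> None:
        self.base_calls = 0
        self.pushed: List[Tuple[TextFrame, FrameDirection]] = []

    async def process_frame(self, frame: TextFrame, direction: FrameDirection) -> None:
        # Base-class bookkeeping that subclasses must not skip.
        self.base_calls += 1

    async def push_frame(self, frame: TextFrame, direction: FrameDirection) -> None:
        # Record the frame instead of handing it to a neighbouring processor.
        self.pushed.append((frame, direction))


class UppercaseProcessor(FrameProcessor):
    """Uppercases downstream text frames and passes upstream frames through."""

    async def process_frame(self, frame: TextFrame, direction: FrameDirection) -> None:
        """Process one frame.

        Args:
            frame: The frame to process.
            direction: The direction the frame is travelling in.
        """
        # Step 2: call the base implementation first.
        await super().process_frame(frame, direction)

        # Step 3: only transform frames that travel downstream.
        if direction is FrameDirection.DOWNSTREAM:
            frame = TextFrame(text=frame.text.upper())
        await self.push_frame(frame, direction)
```

Calling `process_frame()` once per direction leaves the upstream frame untouched while the downstream frame is uppercased, and the base-class counter confirms that `super().process_frame()` ran each time.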

### Transport Implementation
For new transport layers:
1. Inherit from `BaseTransport`
2. Implement required abstract methods
3. Handle connection lifecycle
4. Support both input and output streams

## Security Considerations

### API Keys
- **Never commit API keys** to the repository
- **Use environment variables** for all secrets
- **Reference `env.example`** for required variables
- **Use `.env` files** for local development

### Input Validation
- **Validate all external inputs** (audio, text, API responses)
- **Sanitize user data** before processing
- **Handle rate limiting** for external services
- **Implement proper timeout handling**

## Performance Guidelines

### Memory Management
- **Clean up resources** in transport disconnection handlers
- **Use async context managers** for service connections
- **Implement proper frame lifecycle** management

### Latency Optimization
- **Choose appropriate STT services** for latency requirements
- **Use streaming TTS** when possible
- **Implement connection pooling** for HTTP services
- **Consider WebRTC** for real-time applications

## Common Patterns

### Error Handling
```python
# Assumes `transport`, `task`, `logger`, and `EndFrame` are already in scope.
@transport.event_handler("on_error")
async def on_error(transport, error):
    logger.error(f"Transport error: {error}")

    # Shutdown the pipeline
    await task.queue_frame(EndFrame())
```

### Service Configuration
```python
import os

# Use environment variables for configuration
service = OpenAILLMService(
    api_key=os.getenv("OPENAI_API_KEY", ""),
    model="gpt-4o",
    # `params` takes an InputParams model rather than a plain dict
    params=OpenAILLMService.InputParams(temperature=0.7),
)
```

### Pipeline Assembly
```python
pipeline = Pipeline([
    transport.input(),
    stt_service,
    context_aggregator.user(),
    llm_service,
    tts_service,
    transport.output(),
    context_aggregator.assistant(),
])
```

## Commit and PR Guidelines

### Commit Message Format
```
<type>(<scope>): <description>

[optional body]

[optional footer]
```

Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
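
A commit header in this format can be checked mechanically. The following is a minimal sketch, assuming the scope is mandatory (as the template shows it) and that only the first line needs to match; `is_valid_commit_header` is a hypothetical helper, not an existing project tool.

```python
import re

# The allowed types listed above.
COMMIT_TYPES = ("feat", "fix", "docs", "style", "refactor", "test", "chore")

# <type>(<scope>): <description>
HEADER_RE = re.compile(
    r"^(?P<type>" + "|".join(COMMIT_TYPES) + r")"
    r"\((?P<scope>[^()\s]+)\): (?P<description>\S.*)$"
)


def is_valid_commit_header(message: str) -> bool:
    """Check whether the first line of a commit message follows the format.

    Args:
        message: The full commit message, possibly with body and footer.

    Returns:
        True if the header line matches `<type>(<scope>): <description>`.
    """
    header = message.splitlines()[0] if message else ""
    return HEADER_RE.match(header) is not None
```

For example, `feat(services): add OpenAI agent service` passes, while `feature(services): add thing` and `fix: missing scope` are rejected.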

### PR Requirements
- **All tests must pass**
- **Code must be properly formatted** (Ruff)
- **Include appropriate tests** for new functionality
- **Update documentation** if needed
- **Reference related issues** in description

### Review Process
1. Automated checks must pass
2. Manual code review by maintainers
3. Documentation review for user-facing changes
4. Integration testing for service additions

## Troubleshooting

### Common Issues
- **Import errors:** Run `uv sync` to ensure dependencies are installed
- **Test failures:** Check environment variables in `.env`
- **Format errors:** Run `uv run ruff format` before committing
- **Type errors:** Ensure all public methods have type hints

### Development Tips
- **Use foundational examples** as starting points for testing
- **Check existing services** for integration patterns
- **Run tests frequently** during development
- **Use IDE integration** for Ruff formatting

### Getting Help
- **Documentation:** [docs.pipecat.ai](https://docs.pipecat.ai)
- **Issues:** [GitHub Issues](https://github.com/pipecat-ai/pipecat/issues)
67
CHANGELOG.md
@@ -9,23 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `on_before_disconnect` synchronous event to `DailyTransport` and
  `LiveKitTransport`.

- It is now possible to register synchronous event handlers. By default, all
  event handlers are executed in a separate task. However, in some cases we want
  to guarantee order of execution, for example, executing something before
  disconnecting a transport.

  ```python
  self._register_event_handler("on_event_name", sync=True)
  ```
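
The difference between the two modes can be illustrated with a small stand-alone sketch. `EventEmitter` and its methods (`register_event`, `add_handler`, `emit`, `drain`) are hypothetical and written only for this example; they are not Pipecat's implementation, which this entry does not show.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

Handler = Callable[[], Awaitable[None]]


class EventEmitter:
    """Hypothetical emitter that supports task-based and synchronous events."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Handler]] = {}
        self._sync: Dict[str, bool] = {}
        self._tasks: List[asyncio.Task] = []

    def register_event(self, name: str, *, sync: bool = False) -> None:
        self._handlers[name] = []
        self._sync[name] = sync

    def add_handler(self, name: str, handler: Handler) -> None:
        self._handlers[name].append(handler)

    async def emit(self, name: str) -> None:
        for handler in self._handlers[name]:
            if self._sync[name]:
                # Synchronous event: wait for the handler before continuing.
                await handler()
            else:
                # Default: run the handler in its own task and move on.
                self._tasks.append(asyncio.create_task(handler()))

    async def drain(self) -> None:
        # Wait for any handlers that were started as background tasks.
        await asyncio.gather(*self._tasks)
        self._tasks.clear()


async def disconnect_demo(sync: bool) -> List[str]:
    log: List[str] = []
    emitter = EventEmitter()
    emitter.register_event("on_before_disconnect", sync=sync)

    async def flush() -> None:
        await asyncio.sleep(0.01)
        log.append("handler finished")

    emitter.add_handler("on_before_disconnect", flush)
    await emitter.emit("on_before_disconnect")
    log.append("disconnected")
    await emitter.drain()
    return log
```

With `sync=True` the handler completes before the disconnect is logged; with the default task-based mode the disconnect is logged first.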

- Added support for global location in `GoogleVertexLLMService`. The service now
  supports both regional locations (e.g., "us-east4") and the "global" location
  for Vertex AI endpoints. When using "global" location, the service will use
  `aiplatform.googleapis.com` as the API host instead of the regional format.
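
The host selection described here amounts to a small branch. In the sketch below, `vertex_api_host` is a hypothetical helper, and the regional pattern `<location>-aiplatform.googleapis.com` is an assumption based on Vertex AI's usual regional endpoint naming, since this entry does not spell the regional format out.

```python
def vertex_api_host(location: str) -> str:
    """Pick the Vertex AI API host for a location.

    Args:
        location: A regional location such as "us-east4", or "global".

    Returns:
        The API host name to use for requests.
    """
    if location == "global":
        # The global endpoint has no location prefix.
        return "aiplatform.googleapis.com"
    # Assumed regional format: the location is prefixed to the host.
    return f"{location}-aiplatform.googleapis.com"
```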

- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
  fired when the pipeline is done running. This can be the result of a
  `StopFrame`, `CancelFrame` or `EndFrame`.
@@ -36,64 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
  ...
  ```

### Changed

- Updated Silero VAD model to v6.

- Updated `livekit` to 1.0.13.

- `torch` and `torchaudio` are no longer required for running Smart Turn
  locally. This avoids gigabytes of dependencies being installed.

- Updated `websockets` dependency to support version 15.0. Removed deprecated
  usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
  `AWSTranscribeSTTService` for compatibility.

- Refactored `pyproject.toml` to reduce websockets dependency repetition using
  self-referencing extras. All websockets-dependent services now reference a
  shared `websockets-base` extra.

### Deprecated

- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
  longer needed to determine which transcription or translation frames to
  emit.

- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
  `on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
  instead.

### Fixed

- Fixed an issue where multiple handlers for an event would not run in parallel.

- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
  ID from the `on_dialin_connected` event, when not explicitly provided. Now
  supports cold transfers (from incoming dial-in calls) by automatically
  tracking session IDs from connection events.

- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
  a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
  the code never consumes these frames, they are queued in memory, causing a
  memory leak.
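
The failure mode, and one common remedy (a background task that consumes and discards frames nobody needs), can be sketched without `aiortc`. `FakeTrack` and `discard_frames` are stand-ins invented for this example; the entry does not say how the actual fix was implemented, so this is only an illustration of the general idea.

```python
import asyncio


class FakeTrack:
    """Stand-in for a media track: frames pile up until recv() is called."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def produce(self, frame: bytes) -> None:
        self._queue.put_nowait(frame)

    async def recv(self) -> bytes:
        return await self._queue.get()

    @property
    def buffered(self) -> int:
        return self._queue.qsize()


async def discard_frames(track: FakeTrack) -> None:
    """Consume and drop frames so the track's queue cannot grow unbounded."""
    while True:
        await track.recv()


async def demo() -> tuple:
    track = FakeTrack()
    for _ in range(100):
        track.produce(b"\x00" * 320)
    before = track.buffered  # frames waiting because nobody reads them

    drain_task = asyncio.create_task(discard_frames(track))
    await asyncio.sleep(0.01)  # give the drain task time to run
    after = track.buffered

    # Stop the drain task cleanly once it is no longer needed.
    drain_task.cancel()
    try:
        await drain_task
    except asyncio.CancelledError:
        pass
    return before, after
```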

- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
  pushed.

- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
  not wait if a previous interruption had already happened.

- Fixed a couple of bugs in `ServiceSwitcher`:

  - Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
  - `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
    an effect too early, essentially "jumping the queue" in terms of pipeline
    frame ordering.

- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
  `False` from an idle callback. The task now terminates naturally instead of
  attempting to cancel itself.
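
The "terminate naturally" pattern can be sketched as a loop that returns when the callback asks it to stop, rather than cancelling and awaiting the very task it is running in. `IdleWatcher` and its callback signature are hypothetical and much simpler than `UserIdleProcessor`.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


class IdleWatcher:
    """Hypothetical idle monitor that calls back after each idle period."""

    def __init__(
        self, callback: Callable[[int], Awaitable[bool]], timeout_secs: float
    ) -> None:
        self._callback = callback
        self._timeout_secs = timeout_secs
        self._task: Optional[asyncio.Task] = None

    def start(self) -> None:
        self._task = asyncio.create_task(self._run())

    async def wait_finished(self) -> None:
        if self._task:
            await self._task

    async def _run(self) -> None:
        retries = 0
        while True:
            await asyncio.sleep(self._timeout_secs)
            retries += 1
            keep_going = await self._callback(retries)
            if not keep_going:
                # Return so the task ends on its own. Cancelling and then
                # awaiting this same task from here would wait on itself.
                return


async def demo() -> List[int]:
    seen: List[int] = []

    async def on_idle(retry_count: int) -> bool:
        seen.append(retry_count)
        return retry_count < 3  # returning False stops monitoring

    watcher = IdleWatcher(on_idle, timeout_secs=0.001)
    watcher.start()
    await asyncio.wait_for(watcher.wait_finished(), timeout=1.0)
    return seen
```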

- Fixed an issue in `AudioBufferProcessor` where a recording is not created
  when a bot speaks and user input is blocked.

@@ -21,8 +21,6 @@

🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.

🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.

## 🧠 Why Pipecat?

- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling

@@ -11,7 +11,7 @@ import sys
 from dotenv import load_dotenv
 from loguru import logger
 
-from pipecat.frames.frames import TTSSpeakFrame
+from pipecat.frames.frames import TextFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineTask
@@ -50,7 +50,7 @@ async def main():
     async def on_first_participant_joined(transport, participant_id):
         await asyncio.sleep(1)
         await task.queue_frame(
-            TTSSpeakFrame(
+            TextFrame(
                 "Hello there! How are you doing today? Would you like to talk about the weather?"
             )
         )

@@ -30,6 +30,10 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
 
 load_dotenv(override=True)
 
+# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
+# to the Smart Turn v3 ONNX model file.
+smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
+
 # We store functions so objects (e.g. SileroVADAnalyzer) don't get
 # instantiated. The function will be called when the desired transport gets
 # selected.
@@ -38,19 +42,25 @@ transport_params = {
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
     "twilio": lambda: FastAPIWebsocketParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
     "webrtc": lambda: TransportParams(
         audio_in_enabled=True,
         audio_out_enabled=True,
         vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
-        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(
+            smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
+        ),
     ),
 }

205
examples/foundational/45-openai-agent-basic.py
Normal file
@@ -0,0 +1,205 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""
Basic OpenAI Agent service example.

This example demonstrates how to use the OpenAI Agents SDK within a Pipecat
pipeline to create an interactive agent with tool calling capabilities.

Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""

import os
import random
from typing import Any, List

# Import agents SDK for tools and agent creation
from agents import Agent, function_tool
from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam

from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# Transport configuration
transport_params = {
    "daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
    "twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
    "webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}


@function_tool
def get_weather(location: str) -> str:
    """Get the current weather for a location.

    Args:
        location: The location to get weather for

    Returns:
        A weather description string
    """
    # Mock weather data - in real usage, integrate with weather API
    weather_data = {
        "San Francisco": "Foggy, 65°F",
        "New York": "Sunny, 72°F",
        "London": "Rainy, 59°F",
        "Tokyo": "Partly cloudy, 68°F",
    }
    return weather_data.get(location, f"Weather data not available for {location}")


@function_tool
def get_random_fact() -> str:
    """Get a random interesting fact.

    Returns:
        A random fact string
    """
    facts = [
        "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
        "Octopuses have three hearts and blue blood.",
        "The Great Wall of China isn't visible from space with the naked eye.",
        "Bananas are berries, but strawberries aren't.",
    ]
    return random.choice(facts)


def get_random_fact_tool():
    """Example tool function for random facts."""

    def get_random_fact() -> str:
        """Get a random interesting fact.

        Returns:
            A random fact string.
        """
        facts = [
            "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
            "A group of flamingos is called a 'flamboyance'.",
            "Octopuses have three hearts and blue blood.",
            "The Great Wall of China isn't visible from space with the naked eye.",
            "Bananas are berries, but strawberries aren't.",
        ]
        return random.choice(facts)

    return get_random_fact


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    logger.info("Starting OpenAI Agent bot")

    # Set up STT for speech recognition
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY", ""),
        model="nova-2",
    )

    # Set up TTS for voice output
    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY", ""),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # Create tools for the agent
    tools: list[Any] = [
        get_weather,
        get_random_fact,
    ]

    # Create the agent with tools
    agent = Agent(
        name="Assistant",
        instructions="""You are a helpful assistant with access to weather information and random facts.
        You can:
        - Check weather for any location using the get_weather tool
        - Share interesting facts using the get_random_fact tool
        - Have natural conversations

        Be friendly, informative, and engaging in your responses.""",
        tools=tools,
    )

    # Initialize the OpenAI Agent service with the pre-configured agent
    agent_service = OpenAIAgentService(
        agent=agent,
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Set up conversation context with initial system message
    messages: List[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": "You are a helpful assistant with access to weather information and random facts. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = OpenAILLMContext(messages)
    context_aggregator = agent_service.create_context_aggregator(context)

    # Create the processing pipeline with context aggregators
    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,  # Speech to text
            context_aggregator.user(),  # User responses
            agent_service,  # OpenAI Agent processing
            tts,  # Text to speech
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    # Send an initial greeting when client connects
    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected, sending greeting")
        # Kick off the conversation by adding system message and running LLM
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
    await runner.run(task)


async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud."""
    transport = await create_transport(runner_args, transport_params)
    await run_bot(transport, runner_args)


if __name__ == "__main__":
    from pipecat.runner.run import main

    main()
276
examples/foundational/46-openai-agent-handoffs.py
Normal file
@@ -0,0 +1,276 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""
Advanced OpenAI Agent service example with handoffs.

This example demonstrates how to use multiple agents with handoffs in the
OpenAI Agents SDK within a Pipecat pipeline, showcasing agent orchestration
and specialization.

Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""

import os
import random
from typing import Any, Dict, List

from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam

from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# Transport configuration
transport_params = {
    "daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
    "twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
    "webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}


def create_weather_tools():
    """Create weather-related tools."""

    def get_weather(location: str) -> str:
        """Get current weather for a location."""
        conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"]
        temp = random.randint(-10, 35)
        condition = random.choice(conditions)
        return f"The weather in {location} is {condition} with a temperature of {temp}°C."

    def get_forecast(location: str, days: int = 3) -> str:
        """Get weather forecast for multiple days."""
        forecast = []
        for i in range(days):
            conditions = ["sunny", "cloudy", "rainy", "snowy"]
            temp = random.randint(-5, 30)
            condition = random.choice(conditions)
            day = "today" if i == 0 else f"in {i} day{'s' if i > 1 else ''}"
            forecast.append(f"{day.capitalize()}: {condition}, {temp}°C")
        return f"Weather forecast for {location}:\n" + "\n".join(forecast)

    return [get_weather, get_forecast]


def create_trivia_tools():
    """Create trivia and fact tools."""

    def get_random_fact() -> str:
        """Get a random interesting fact."""
        facts = [
            "Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
            "A group of flamingos is called a 'flamboyance'.",
            "Octopuses have three hearts and blue blood.",
            "The Great Wall of China isn't visible from space with the naked eye.",
            "Bananas are berries, but strawberries aren't.",
            "Wombat poop is cube-shaped.",
            "A shrimp's heart is in its head.",
            "It's impossible to hum while holding your nose.",
        ]
        return random.choice(facts)

    def get_science_fact() -> str:
        """Get a random science fact."""
        facts = [
            "The speed of light in a vacuum is approximately 299,792,458 meters per second.",
            "DNA stands for Deoxyribonucleic Acid.",
            "The human brain uses about 20% of the body's total energy.",
            "There are more possible games of chess than atoms in the observable universe.",
            "A single bolt of lightning contains enough energy to toast 100,000 slices of bread.",
        ]
        return random.choice(facts)

    return [get_random_fact, get_science_fact]


def create_math_tools():
    """Create math calculation tools."""

    def calculate(expression: str) -> str:
        """Safely calculate a mathematical expression."""
        try:
            # Only allow basic math operations for safety
            allowed_chars = set("0123456789+-*/.() ")
            if not all(c in allowed_chars for c in expression):
                return "Sorry, I can only calculate basic math expressions with +, -, *, /, and parentheses."

            # The whitelist above already blocks names and attribute access;
            # empty builtins and locals add a second layer of defence.
            result = eval(expression, {"__builtins__": {}}, {})
            return f"{expression} = {result}"
        except Exception as e:
            return f"Error calculating '{expression}': {str(e)}"

    def generate_math_problem() -> str:
        """Generate a random math problem."""
        operations = ["+", "-", "*"]
        a = random.randint(1, 20)
        b = random.randint(1, 20)
        op = random.choice(operations)

        # The answer is not needed here; the user is asked to solve it.
        return f"Here's a math problem for you: {a} {op} {b} = ?"

    return [calculate, generate_math_problem]


async def create_specialist_agents():
    """Create specialized agents for different domains."""

    # Weather specialist agent
    weather_agent = OpenAIAgentService(
        name="Weather Specialist",
        instructions="""You are a weather specialist. You provide detailed weather information,
        forecasts, and weather-related advice. Use your tools to get accurate weather data.
        Be informative and helpful about weather conditions and what they might mean for
        outdoor activities.""",
        tools=create_weather_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Trivia specialist agent
    trivia_agent = OpenAIAgentService(
        name="Trivia Master",
        instructions="""You are a trivia and facts specialist. You love sharing interesting
        facts, trivia, and educational content. Use your tools to provide fascinating
        information and engage users with fun facts. Make learning enjoyable!""",
        tools=create_trivia_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Math specialist agent
    math_agent = OpenAIAgentService(
        name="Math Helper",
        instructions="""You are a mathematics specialist. You help with calculations,
        math problems, and mathematical concepts. Use your tools to solve problems
        and generate practice questions. Make math accessible and fun!""",
        tools=create_math_tools(),
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    return weather_agent, trivia_agent, math_agent


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    logger.info("Starting OpenAI Agent bot with handoffs")

    # Set up STT for speech recognition
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY", ""),
        model="nova-2",
    )

    # Set up TTS for voice output
    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY", ""),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # Create specialist agents
    weather_agent, trivia_agent, math_agent = await create_specialist_agents()

    # Create the main triage agent that can hand off to specialists
    triage_agent = OpenAIAgentService(
        name="Assistant Coordinator",
        instructions="""You are a helpful assistant coordinator. Your role is to understand
        what the user needs and direct them to the right specialist:

        - For weather questions, forecasts, or outdoor activity planning -> Weather Specialist
        - For interesting facts, trivia, or educational content -> Trivia Master
        - For calculations, math problems, or mathematical help -> Math Helper

        If the request doesn't clearly fit a specialist, you can handle general conversation
        yourself. Always be friendly and explain when you're connecting them to a specialist.""",
        handoffs=[weather_agent.agent, trivia_agent.agent, math_agent.agent],  # type: ignore
        api_key=os.getenv("OPENAI_API_KEY"),
        streaming=True,
    )

    # Set up conversation context with initial system message
    messages: List[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": "You are a helpful assistant coordinator with access to weather information, trivia, and math tools. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = OpenAILLMContext(messages)
    context_aggregator = triage_agent.create_context_aggregator(context)

    # Create the processing pipeline with context aggregators
    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,  # Speech to text
            context_aggregator.user(),  # User responses
            triage_agent,  # OpenAI Agent processing
|
||||
tts, # Text to speech
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Send an initial greeting when client connects
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected, sending greeting")
|
||||
# Kick off the conversation by adding system message and running LLM
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please introduce yourself to the user as an AI assistant coordinator who works with specialists for weather, trivia, and math topics.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<=1.99.1",
|
||||
"openai>=1.74.0,<2.0.0",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
@@ -47,68 +47,68 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.0.1" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
|
||||
assemblyai = [ "websockets>=13.1,<15.0" ]
|
||||
asyncai = [ "websockets>=13.1,<15.0" ]
|
||||
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.19.9" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
elevenlabs = [ "websockets>=13.1,<15.0" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
|
||||
gladia = [ "websockets>=13.1,<15.0" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
|
||||
heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
|
||||
inworld = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mcp = [ "mcp[cli]>=1.11.0,<2.0.0" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openai-agent = [ "openai-agents~=0.3.0" ]
|
||||
# openpipe = [ "openpipe~=4.50.0" ] # Temporarily disabled due to openai version conflict
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
qwen = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
rime = [ "websockets>=13.1,<15.0" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sambanova = []
|
||||
sarvam = [ "pipecat-ai[websockets-base]" ]
|
||||
sarvam = [ "websockets>=13.1,<15.0" ]
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
|
||||
local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime>=1.20.1, <2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime>=1.20.1, <2" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soniox = [ "websockets>=13.1,<15.0" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.4.0" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
|
||||
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
|
||||
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
|
||||
websockets-base = [ "websockets>=13.1,<16.0" ]
|
||||
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
|
||||
whisper = [ "faster-whisper~=1.1.1" ]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -1,12 +0,0 @@
#!/bin/bash

PID=$1

while true; do
    # Clear the screen
    clear
    # Print the header + RSS in GB
    ps -p "$PID" -o pid,comm,rss | \
        awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
    sleep 1
done
@@ -98,15 +98,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
inputs = self._feature_extractor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=8 * 16000,
|
||||
truncation=True,
|
||||
do_normalize=True,
|
||||
)
|
||||
|
||||
# Extract features and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).astype(np.float32)
|
||||
# Convert to numpy and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32)
|
||||
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
|
||||
|
||||
# Run ONNX inference
|
||||
|
||||
Binary file not shown.
@@ -1604,7 +1604,7 @@ class MixerEnableFrame(MixerControlFrame):
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFrame(ControlFrame):
|
||||
"""A base class for frames that affect ServiceSwitcher behavior."""
|
||||
"""A base class for frames that control ServiceSwitcher behavior."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,15 +6,9 @@
|
||||
|
||||
"""Service switcher for switching between different services at runtime, with different switching strategies."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
ServiceSwitcherFrame,
|
||||
)
|
||||
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -28,6 +22,19 @@ class ServiceSwitcherStrategy:
|
||||
self.services = services
|
||||
self.active_service: Optional[FrameProcessor] = None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Determine if the given service is the currently active one.
|
||||
|
||||
This method should be overridden by subclasses to implement specific logic.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method.")
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -53,6 +60,17 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
super().__init__(services)
|
||||
self.active_service = services[0] if services else None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Check if the given service is the currently active one.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
return service == self.active_service
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -61,21 +79,20 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
"""
|
||||
if isinstance(frame, ManuallySwitchServiceFrame):
|
||||
self._set_active_if_available(frame.service)
|
||||
self._set_active(frame.service)
|
||||
else:
|
||||
raise ValueError(f"Unsupported frame type: {type(frame)}")
|
||||
|
||||
def _set_active_if_available(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
intended for another ServiceSwitcher in the pipeline.
|
||||
def _set_active(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one.
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
"""
|
||||
if service in self.services:
|
||||
self.active_service = service
|
||||
else:
|
||||
raise ValueError(f"Service {service} is not in the list of available services.")
|
||||
|
||||
|
||||
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
|
||||
@@ -91,43 +108,6 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
self.services = services
|
||||
self.strategy = strategy
|
||||
|
||||
class ServiceSwitcherFilter(FunctionFilter):
|
||||
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapped_service: FrameProcessor,
|
||||
active_service: FrameProcessor,
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction."""
|
||||
|
||||
async def filter(_: Frame) -> bool:
|
||||
return self._wrapped_service == self._active_service
|
||||
|
||||
super().__init__(filter, direction)
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
|
||||
self._active_service = frame.active_service
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. Push the
|
||||
# frame only to update the other side of the sandwich, but
|
||||
# otherwise don't let it leave the sandwich.
|
||||
if direction == self._direction:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
|
||||
|
||||
active_service: FrameProcessor
|
||||
|
||||
@staticmethod
|
||||
def _make_pipeline_definitions(
|
||||
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
|
||||
@@ -141,18 +121,14 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def _make_pipeline_definition(
|
||||
service: FrameProcessor, strategy: ServiceSwitcherStrategy
|
||||
) -> Any:
|
||||
async def filter(frame) -> bool:
|
||||
_ = frame
|
||||
return strategy.is_active(service)
|
||||
|
||||
return [
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.DOWNSTREAM,
|
||||
),
|
||||
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
|
||||
service,
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.UPSTREAM,
|
||||
),
|
||||
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
|
||||
]
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -166,7 +142,3 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
|
||||
if isinstance(frame, ServiceSwitcherFrame):
|
||||
self.strategy.handle_frame(frame, direction)
|
||||
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
|
||||
active_service=self.strategy.active_service
|
||||
)
|
||||
await super().process_frame(service_switcher_filter_frame, direction)
|
||||
|
||||
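Note on the `is_active` variant in the ServiceSwitcher hunks above: each per-service filter is a closure that asks the strategy at call time, so a switch takes effect on the next frame without any internal update frame being pushed. The sketch below illustrates only that idea and does not import Pipecat; `Strategy`, `make_filter`, and the string service names are hypothetical stand-ins, not the library's classes.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


class Strategy:
    """Hypothetical stand-in for a switching strategy (not the Pipecat class)."""

    def __init__(self, services: List[str]):
        self.services = services
        # The first service starts out active, as in the manual strategy above.
        self.active_service: Optional[str] = services[0] if services else None

    def is_active(self, service: str) -> bool:
        return service == self.active_service

    def set_active(self, service: str) -> None:
        if service not in self.services:
            raise ValueError(f"Service {service} is not in the list of available services.")
        self.active_service = service


def make_filter(service: str, strategy: Strategy) -> Callable[[object], Awaitable[bool]]:
    """Build a frame filter that consults the strategy every time it is called."""

    async def frame_filter(_frame: object) -> bool:
        return strategy.is_active(service)

    return frame_filter


async def demo() -> List[bool]:
    strategy = Strategy(["stt_a", "stt_b"])
    filter_a = make_filter("stt_a", strategy)
    filter_b = make_filter("stt_b", strategy)
    before = [await filter_a(None), await filter_b(None)]
    strategy.set_active("stt_b")  # no filter object needs to be updated
    after = [await filter_a(None), await filter_b(None)]
    return before + after


result = asyncio.run(demo())
print(result)
```

Because the filters hold a reference to the strategy rather than a copy of the active service, there is a single source of truth for which branch is open.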
@@ -220,11 +220,6 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
# the start of the pipeline back to the processor that sent the
|
||||
# `InterruptionTaskFrame`. This wait is handled using the following
|
||||
# event.
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
@@ -568,17 +563,11 @@ class FrameProcessor(BaseObject):
|
||||
"""Pause processing of queued frames."""
|
||||
logger.trace(f"{self}: pausing frame processing")
|
||||
self.__should_block_frames = True
|
||||
# We should also unset the process event here, in case it was set immediately after an interruption
|
||||
if self.__process_event:
|
||||
self.__process_event.clear()
|
||||
|
||||
async def pause_processing_system_frames(self):
|
||||
"""Pause processing of queued system frames."""
|
||||
logger.trace(f"{self}: pausing system frame processing")
|
||||
self.__should_block_system_frames = True
|
||||
# We should also unset the input event here, in case it was set immediately after an interruption
|
||||
if self.__input_event:
|
||||
self.__input_event.clear()
|
||||
|
||||
async def resume_processing_frames(self):
|
||||
"""Resume processing of queued frames."""
|
||||
@@ -643,9 +632,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
|
||||
# If we are waiting for an interruption and we get an interruption, then
|
||||
# we can unblock `push_interruption_task_frame_and_wait()`.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
|
||||
@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -184,13 +185,15 @@ class UserIdleProcessor(FrameProcessor):
|
||||
|
||||
Runs in a loop until cancelled or callback indicates completion.
|
||||
"""
|
||||
running = True
|
||||
while running:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._interrupted:
|
||||
self._retry_count += 1
|
||||
running = await self._callback(self, self._retry_count)
|
||||
should_continue = await self._callback(self, self._retry_count)
|
||||
if not should_continue:
|
||||
await self._stop()
|
||||
break
|
||||
finally:
|
||||
self._idle_event.clear()
|
||||
|
||||
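Note on the idle loop in the hunk above: the `while True` variant treats the callback's return value as "keep monitoring?" and leaves the loop explicitly when it is false. A minimal standalone sketch of that control flow, assuming a hypothetical `idle_loop` function in place of the processor method, and leaving out the `_interrupted` check and the `_stop()` call:

```python
import asyncio
from typing import Awaitable, Callable, List


async def idle_loop(
    idle_event: asyncio.Event,
    timeout: float,
    callback: Callable[[int], Awaitable[bool]],
) -> int:
    """Invoke `callback` after each idle timeout until it returns False."""
    retry_count = 0
    while True:
        try:
            # Wakes early if activity sets the event; otherwise times out.
            await asyncio.wait_for(idle_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            retry_count += 1
            should_continue = await callback(retry_count)
            if not should_continue:
                break
        finally:
            # Runs on every iteration, including the one that breaks.
            idle_event.clear()
    return retry_count


async def demo() -> List[int]:
    seen: List[int] = []

    async def on_idle(retry: int) -> bool:
        seen.append(retry)
        return retry < 3  # stop after the third idle timeout

    await idle_loop(asyncio.Event(), timeout=0.01, callback=on_idle)
    return seen


seen_retries = asyncio.run(demo())
print(seen_retries)
```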
@@ -70,6 +70,7 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -182,14 +183,13 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.request_handler import (
|
||||
SmallWebRTCRequest,
|
||||
SmallWebRTCRequestHandler,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"WebRTC transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
# Mount the frontend
|
||||
app.mount("/client", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -198,33 +198,51 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
"""Redirect root requests to client interface."""
|
||||
return RedirectResponse(url="/client/")
|
||||
|
||||
# Initialize the SmallWebRTC request handler
|
||||
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
|
||||
esp32_mode=esp32_mode, host=host
|
||||
)
|
||||
|
||||
@app.post("/api/offer")
|
||||
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
|
||||
async def offer(request: dict, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests and manage peer connections."""
|
||||
pc_id = request.get("pc_id")
|
||||
|
||||
if pc_id and pc_id in pcs_map:
|
||||
pipecat_connection = pcs_map[pc_id]
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(
|
||||
sdp=request["sdp"],
|
||||
type=request["type"],
|
||||
restart_pc=request.get("restart_pc", False),
|
||||
)
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection()
|
||||
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
"""Handle WebRTC connection closure and cleanup."""
|
||||
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
|
||||
pcs_map.pop(webrtc_connection.pc_id, None)
|
||||
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
answer = await small_webrtc_handler.handle_web_request(
|
||||
request=request,
|
||||
webrtc_connection_callback=webrtc_connection_callback,
|
||||
)
|
||||
answer = pipecat_connection.get_answer()
|
||||
|
||||
# Apply ESP32 SDP munging if enabled
|
||||
if esp32_mode and host != "localhost":
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
|
||||
|
||||
pcs_map[answer["pc_id"]] = pipecat_connection
|
||||
return answer
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage FastAPI application lifecycle and cleanup connections."""
|
||||
yield
|
||||
await small_webrtc_handler.close()
|
||||
coros = [pc.disconnect() for pc in pcs_map.values()]
|
||||
await asyncio.gather(*coros)
|
||||
pcs_map.clear()
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
|
||||
@@ -119,6 +119,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -532,7 +532,9 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
logger.error(
|
||||
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
|
||||
@@ -13,7 +13,6 @@ supporting multiple languages, custom vocabulary, and various audio processing o
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -174,6 +173,8 @@ class _InputParamsDescriptor:
|
||||
"""Descriptor for backward compatibility with deprecation warning."""
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
@@ -207,7 +208,7 @@ class GladiaSTTService(STTService):
|
||||
api_key: str,
|
||||
region: Literal["us-west", "eu-west"] | None = None,
|
||||
url: str = "https://api.gladia.io/v2/live",
|
||||
confidence: Optional[float] = None,
|
||||
confidence: float = 0.5,
|
||||
sample_rate: Optional[int] = None,
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
@@ -223,11 +224,6 @@ class GladiaSTTService(STTService):
|
||||
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
|
||||
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
|
||||
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
The 'confidence' parameter is deprecated and will be removed in a future version.
|
||||
No confidence threshold is applied.
|
||||
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
model: Model to use for transcription. Defaults to "solaria-1".
|
||||
params: Additional configuration parameters for Gladia service.
|
||||
@@ -240,6 +236,7 @@ class GladiaSTTService(STTService):
|
||||
|
||||
params = params or GladiaInputParams()
|
||||
|
||||
# Warn about deprecated language parameter if it's used
|
||||
if params.language is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
@@ -250,20 +247,11 @@ class GladiaSTTService(STTService):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if confidence:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The 'confidence' parameter is deprecated and will be removed in a future version. "
|
||||
"No confidence threshold is applied.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._url = url
|
||||
self.set_model_name(model)
|
||||
self._confidence = confidence
|
||||
self._params = params
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
@@ -587,40 +575,43 @@ class GladiaSTTService(STTService):
|
||||
|
||||
elif content["type"] == "transcript":
|
||||
utterance = content["data"]["utterance"]
|
||||
confidence = utterance.get("confidence", 0)
|
||||
language = utterance["language"]
|
||||
transcript = utterance["text"]
|
||||
is_final = content["data"]["is_final"]
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
if confidence >= self._confidence:
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif content["type"] == "translation":
|
||||
translated_utterance = content["data"]["translated_utterance"]
|
||||
original_language = content["data"]["original_language"]
|
||||
translated_language = translated_utterance["language"]
|
||||
confidence = translated_utterance.get("confidence", 0)
|
||||
translation = translated_utterance["text"]
|
||||
if translated_language != original_language:
|
||||
if translated_language != original_language and confidence >= self._confidence:
|
||||
await self.push_frame(
|
||||
TranslationFrame(
|
||||
translation, "", time_now_iso8601(), translated_language
|
||||
|
||||
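Note on the confidence gate in the Gladia hunk above: in the variant that compares `utterance.get("confidence", 0)` against a threshold, a message with no `confidence` field is treated as confidence 0 and is dropped for any positive threshold. A minimal sketch of that behaviour, assuming a hypothetical helper `accept_utterance` and plain dictionaries in place of the service's message handling:

```python
from typing import Any, Dict, Optional


def accept_utterance(utterance: Dict[str, Any], threshold: float = 0.5) -> Optional[str]:
    """Return the utterance text if it meets the threshold, else None."""
    # A missing confidence value defaults to 0, mirroring the hunk above.
    confidence = utterance.get("confidence", 0)
    if confidence >= threshold:
        return utterance["text"]
    return None


kept = accept_utterance({"text": "hello", "confidence": 0.9})
dropped_low = accept_utterance({"text": "mumble", "confidence": 0.2})
dropped_missing = accept_utterance({"text": "no score"})
print(kept, dropped_low, dropped_missing)
```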
@@ -83,23 +83,14 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
self._api_key = self._get_api_token(credentials, credentials_path)
|
||||
|
||||
super().__init__(
|
||||
api_key=self._api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
params=params,
|
||||
**kwargs,
|
||||
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_base_url(params: InputParams) -> str:
|
||||
"""Construct the base URL for Vertex AI API."""
|
||||
# Determine the correct API host based on location
|
||||
if params.location == "global":
|
||||
api_host = "aiplatform.googleapis.com"
|
||||
else:
|
||||
api_host = f"{params.location}-aiplatform.googleapis.com"
|
||||
return (
|
||||
f"https://{api_host}/v1/"
|
||||
f"https://{params.location}-aiplatform.googleapis.com/v1/"
|
||||
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
|
||||
)
|
||||
|
||||
@@ -127,14 +118,12 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
if credentials:
|
||||
# Parse and load credentials from JSON string
|
||||
creds = service_account.Credentials.from_service_account_info(
|
||||
json.loads(credentials),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
elif credentials_path:
|
||||
# Load credentials from JSON file
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
credentials_path,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
||||
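Note on `_get_base_url` in the hunk above: the two versions differ only in the host. One always prefixes it with the location; the other uses the bare `aiplatform.googleapis.com` host when the location is `global`. A minimal sketch of the second variant as a free function, assuming hypothetical arguments `project_id` and `location` in place of the `params` object:

```python
def vertex_base_url(project_id: str, location: str) -> str:
    """Build the base URL following the logic shown in the hunk above."""
    # The "global" location has no regional host prefix.
    if location == "global":
        api_host = "aiplatform.googleapis.com"
    else:
        api_host = f"{location}-aiplatform.googleapis.com"
    return (
        f"https://{api_host}/v1/"
        f"projects/{project_id}/locations/{location}/endpoints/openapi"
    )


regional_url = vertex_base_url("my-project", "us-east4")
global_url = vertex_base_url("my-project", "global")
print(regional_url)
print(global_url)
```

With the always-prefixed variant, a `global` location would produce the host `global-aiplatform.googleapis.com` instead.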
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -28,8 +28,6 @@ except ModuleNotFoundError as e:
|
||||
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
class MCPClient(BaseObject):
|
||||
"""Client for Model Context Protocol (MCP) servers.
|
||||
@@ -44,7 +42,7 @@ class MCPClient(BaseObject):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
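Note on the `server_params` annotation in the hunk above: the two spellings mean different things to a type checker. A union (`A | B | C`) accepts any one of the three parameter types, while `Tuple[A, B, C]` describes a three-element tuple holding one of each. The sketch below shows the difference with hypothetical stand-in dataclasses (`StdioParams`, `SseParams`, `HttpParams`), not the real MCP classes:

```python
from dataclasses import dataclass
from typing import Tuple, Union, get_args, get_origin


@dataclass
class StdioParams:  # hypothetical stand-in
    command: str


@dataclass
class SseParams:  # hypothetical stand-in
    url: str


@dataclass
class HttpParams:  # hypothetical stand-in
    url: str


# "Any one of these" versus "a 3-tuple containing all of these".
AnyOf = Union[StdioParams, SseParams, HttpParams]
AllOf = Tuple[StdioParams, SseParams, HttpParams]


def accepts(value: object) -> bool:
    """Runtime check equivalent to the union annotation."""
    return isinstance(value, get_args(AnyOf))


union_origin = get_origin(AnyOf)
tuple_origin = get_origin(AllOf)
print(accepts(SseParams(url="http://localhost:8000/sse")))
print(accepts((StdioParams("cmd"), SseParams("u"), HttpParams("u"))))
```

A single `SseParams` instance satisfies the union but not the tuple annotation; the three-element tuple is the reverse.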
209
src/pipecat/services/openai_agent/README.md
Normal file
@@ -0,0 +1,209 @@
# OpenAI Agents SDK Integration

This service integrates the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) with Pipecat, enabling agentic workflows with features such as:

- **Agent loops** with tool calling and response streaming
- **Handoffs** between specialized agents
- **Guardrails** for input/output validation
- **Sessions** with automatic conversation history
- **Built-in tracing** and monitoring

## Installation

Install the OpenAI Agents SDK dependency:

```bash
pip install "pipecat-ai[openai-agent]"
# or
uv add "pipecat-ai[openai-agent]"
```

## Basic Usage

```python
import os

from pipecat.pipeline.pipeline import Pipeline
from pipecat.services.openai_agent import OpenAIAgentService

# Create a simple agent
agent_service = OpenAIAgentService(
    name="Assistant",
    instructions="You are a helpful assistant.",
    api_key=os.getenv("OPENAI_API_KEY"),
    streaming=True,
)

# Use in a pipeline (transport, stt, and tts are assumed to be created elsewhere)
pipeline = Pipeline([
    transport.input(),
    stt,
    agent_service,
    tts,
    transport.output(),
])
```

## Features

### Tool Integration

```python
import os

from agents import function_tool


# The Agents SDK expects tool objects; `function_tool` wraps a plain function.
@function_tool
def get_weather(location: str) -> str:
    """Get weather for a location."""
    return f"Weather in {location}: sunny, 22°C"


agent_service = OpenAIAgentService(
    name="Weather Assistant",
    instructions="Help users with weather information.",
    tools=[get_weather],
    api_key=os.getenv("OPENAI_API_KEY"),
)
```

### Agent Handoffs

```python
# get_weather, get_forecast, and get_random_fact are assumed to be defined elsewhere.

# Create specialized agents
weather_agent = OpenAIAgentService(
    name="Weather Specialist",
    instructions="Provide weather information and forecasts.",
    tools=[get_weather, get_forecast],
)

trivia_agent = OpenAIAgentService(
    name="Trivia Master",
    instructions="Share interesting facts and trivia.",
    tools=[get_random_fact],
)

# Create coordinator that can hand off to specialists
coordinator = OpenAIAgentService(
    name="Coordinator",
    instructions="Route users to the right specialist.",
    handoffs=[weather_agent.agent, trivia_agent.agent],
)
```

### Guardrails

```python
from agents import InputGuardrail, GuardrailFunctionOutput

async def content_filter(ctx, agent, input_data):
    # Check input for appropriate content (is_inappropriate is assumed to be defined elsewhere)
    if is_inappropriate(input_data):
        return GuardrailFunctionOutput(
            tripwire_triggered=True,
            output_info="Content not allowed",
        )
    # output_info is a required field, so pass it explicitly even when nothing tripped
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info=None)

agent_service = OpenAIAgentService(
    name="Safe Assistant",
    instructions="You are a helpful and safe assistant.",
    input_guardrails=[InputGuardrail(guardrail_function=content_filter)],
)
```

### Session Management

```python
agent_service = OpenAIAgentService(
    name="Personal Assistant",
    instructions="Remember user preferences and context.",
    session_config={
        "user_id": "user_123",
        "memory_enabled": True,
    }
)

# Update session context dynamically
agent_service.update_session_context({
    "user_preferences": {"language": "en", "style": "formal"}
})
```

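The session context behaves like a plain dictionary: `get_session_context()` returns a shallow copy and `update_session_context()` merges keys in place. The following is a minimal standalone sketch of those semantics. It uses a hypothetical `SessionContext` class that is not part of Pipecat or the Agents SDK.

```python
from typing import Any, Dict, Optional


class SessionContext:
    """Hypothetical stand-in for the service's session dictionary."""

    def __init__(self, initial: Optional[Dict[str, Any]] = None) -> None:
        self._data: Dict[str, Any] = dict(initial or {})

    def get(self) -> Dict[str, Any]:
        # Return a shallow copy so callers cannot replace top-level keys by accident.
        return self._data.copy()

    def update(self, changes: Dict[str, Any]) -> None:
        # Merge new keys; existing keys with the same name are overwritten.
        self._data.update(changes)


session = SessionContext({"user_id": "user_123", "memory_enabled": True})
session.update({"user_preferences": {"language": "en", "style": "formal"}})

snapshot = session.get()
snapshot["user_id"] = "someone_else"  # only the copy changes
```

Because the copy is shallow, nested values such as `user_preferences` are still shared with the original dictionary.
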
## Configuration Options

### Basic Parameters

- `name`: Agent identifier for handoffs and tracing
- `instructions`: System prompt defining agent behavior
- `api_key`: OpenAI API key (or use `OPENAI_API_KEY` env var)
- `streaming`: Enable real-time token streaming (default: True)

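With `streaming` enabled, the service forwards token deltas as they arrive and falls back to the completed message only when no deltas were seen, so text is not emitted twice. The sketch below models that rule with plain tuples. The event labels and the `collect_text` helper are illustrative assumptions, not the Agents SDK event types.

```python
from typing import Iterable, List, Tuple

# Each event is (kind, text); "delta" and "message" are illustrative labels only.
Event = Tuple[str, str]


def collect_text(events: Iterable[Event]) -> List[str]:
    """Return the text chunks that would be pushed downstream."""
    chunks: List[str] = []
    saw_delta = False
    for kind, text in events:
        if kind == "delta" and text:
            saw_delta = True
            chunks.append(text)
        elif kind == "message" and not saw_delta and text:
            # No deltas arrived, so the complete message is the only source of text.
            chunks.append(text)
    return chunks


streamed = collect_text([("delta", "Hel"), ("delta", "lo"), ("message", "Hello")])
fallback = collect_text([("message", "Hello")])
```
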
### Advanced Configuration

- `tools`: List of callable functions for the agent to use
- `handoffs`: List of other agents this agent can transfer to
- `input_guardrails`: Input validation and filtering
- `output_guardrails`: Output validation and filtering
- `model_config`: Model settings; in this version only the `model` key is passed to the agent (default `gpt-4o`)
- `session_config`: Session and memory configuration, passed to each agent run as its context

### Model Configuration

```python
agent_service = OpenAIAgentService(
    name="Precise Assistant",
    instructions="Provide accurate, concise responses.",
    model_config={
        "model": "gpt-4o",
        "temperature": 0.1,
        "max_tokens": 150,
    }
)
```

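The model name falls back to `gpt-4o` when `model_config` is missing or has no `model` key. The sketch below shows that lookup. `resolve_model` and `DEFAULT_MODEL` are hypothetical names used only for illustration.

```python
from typing import Any, Dict, Optional

DEFAULT_MODEL = "gpt-4o"


def resolve_model(model_config: Optional[Dict[str, Any]] = None) -> str:
    """Pick the configured model name, or the default when none is given."""
    if not model_config:
        return DEFAULT_MODEL
    return model_config.get("model", DEFAULT_MODEL)
```
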
## Examples

See the foundational examples:

- [`45-openai-agent-basic.py`](../../../../examples/foundational/45-openai-agent-basic.py) - Basic agent with tools
- [`46-openai-agent-handoffs.py`](../../../../examples/foundational/46-openai-agent-handoffs.py) - Multi-agent system with handoffs

## Methods

### Core Methods

- `update_agent_config()` - Update the agent's instructions; a `model` entry in `model_config` only updates the model name reported for metrics
- `add_tool()` - Add new tools dynamically (coroutine, must be awaited)
- `add_handoff_agent()` - Add handoff destinations (coroutine, must be awaited)
- `get_session_context()` - Get a copy of the current session state
- `update_session_context()` - Update session variables

### Lifecycle Methods

Inherited from `AIService`:
- `start()` - Initialize the agent
- `stop()` - Clean up resources
- `cancel()` - Cancel ongoing operations

## Integration with Pipecat

The service processes `OpenAILLMContextFrame` and `TextFrame` inputs and generates:

- `LLMFullResponseStartFrame` - Response beginning
- `LLMTextFrame` - Streaming text tokens, or a single frame with the final output when streaming is disabled
- `LLMFullResponseEndFrame` - Response completion

This lets the service plug into Pipecat's conversation pipeline and context aggregators.

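When a context frame arrives, the most recent user message is used as the agent input. The following is a minimal standalone sketch of that extraction, assuming OpenAI-style message dictionaries. The helper name `latest_user_text` is hypothetical.

```python
from typing import Any, Dict, List


def latest_user_text(messages: List[Dict[str, Any]]) -> str:
    """Return the text of the most recent user message, or "" if there is none."""
    for message in reversed(messages):
        if message.get("role") != "user":
            continue
        content = message.get("content", "")
        if isinstance(content, list):
            # Multi-part content: keep only the text parts.
            parts = [
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            return " ".join(parts)
        return str(content)
    return ""


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": [{"type": "text", "text": "What's the weather?"}]},
]
```
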
## Error Handling

The service handles the following failure cases:

- Missing API keys or SDK installation
- Agent processing failures
- Network connectivity issues
- Malformed tool responses

Errors are emitted as `ErrorFrame` objects in the pipeline.

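One way to picture the error path: the agent call is wrapped in `try`/`except`, and a failure becomes an error item pushed downstream instead of an exception. The sketch below uses a plain list as a stand-in for the pipeline and hypothetical names (`run_turn`, `echo_agent`, `failing_agent`). It does not use Pipecat's frame classes.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple

# (kind, payload) tuples stand in for pipeline frames in this sketch.
Pushed = List[Tuple[str, str]]


async def run_turn(agent_call: Callable[[str], Awaitable[str]], text: str, out: Pushed) -> None:
    """Run one agent turn, recording an error item if the call fails."""
    out.append(("start", ""))
    try:
        out.append(("text", await agent_call(text)))
        out.append(("end", ""))
    except Exception as exc:
        # Failures surface as error items rather than propagating to the caller.
        out.append(("error", f"Agent processing error: {exc}"))


async def echo_agent(text: str) -> str:
    return f"echo: {text}"


async def failing_agent(text: str) -> str:
    raise RuntimeError("network unreachable")


ok: Pushed = []
failed: Pushed = []
asyncio.run(run_turn(echo_agent, "hi", ok))
asyncio.run(run_turn(failing_agent, "hi", failed))
```

In this sketch no end item follows an error, so a consumer waiting for the end of a response may need to treat the error item as terminal.
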
## Requirements

- OpenAI API key
- `openai-agents` package
- Python 3.10+

## Limitations

- Currently supports OpenAI models only (via Agents SDK)
- Handoffs work within individual requests (no cross-request state)
- Real-time voice features require additional setup
11
src/pipecat/services/openai_agent/__init__.py
Normal file
@@ -0,0 +1,11 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""OpenAI Agents SDK service for Pipecat integration."""

from .agent_service import OpenAIAgentService

__all__ = ["OpenAIAgentService"]
567
src/pipecat/services/openai_agent/agent_service.py
Normal file
@@ -0,0 +1,567 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Agents SDK integration service.
|
||||
|
||||
Provides integration with the OpenAI Agents SDK for building AI applications
|
||||
within Pipecat pipelines. This service allows leveraging agent loops, handoffs,
|
||||
guardrails, sessions, and tools from the OpenAI Agents SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    Union,
    runtime_checkable,
)

# typing.override only exists on Python 3.12+, while the project minimum is 3.10.
# Assumes typing_extensions is installed (it is not declared in this diff).
from typing_extensions import override
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from agents import Agent, InputGuardrail, OutputGuardrail, Runner, Tool
|
||||
from agents.result import RunResult, RunResultStreaming
|
||||
from agents.stream_events import StreamEvent
|
||||
except ImportError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI Agents SDK, you need to `pip install openai-agents`. "
|
||||
"Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolLike(Protocol):
|
||||
"""Protocol for tool-like objects."""
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Tool call interface."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentLike(Protocol):
|
||||
"""Protocol for agent-like objects."""
|
||||
|
||||
name: str
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Agent call interface."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIAgentContextAggregatorPair:
|
||||
"""Pair of OpenAI Agent context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
_user: "OpenAIAgentUserContextAggregator"
|
||||
_assistant: "OpenAIAgentAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIAgentUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAgentAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAIAgentService(AIService):
|
||||
"""OpenAI Agents SDK service for Pipecat.
|
||||
|
||||
Integrates the OpenAI Agents SDK with Pipecat's pipeline architecture,
|
||||
enabling advanced agentic workflows with features like handoffs, guardrails,
|
||||
sessions, and tools within real-time conversational AI applications.
|
||||
|
||||
The service processes text input frames and generates streaming responses
|
||||
using the agent's configured capabilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: Optional[Agent] = None,
|
||||
name: str = "Assistant",
|
||||
instructions: Union[str, Sequence[str]] = "You are a helpful assistant.",
|
||||
handoffs: Optional[Sequence[AgentLike]] = None,
|
||||
tools: Optional[Sequence[ToolLike]] = None,
|
||||
input_guardrails: Optional[Sequence[InputGuardrail]] = None,
|
||||
output_guardrails: Optional[Sequence[OutputGuardrail]] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
session_config: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
streaming: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Agent service.
|
||||
|
||||
Args:
|
||||
agent: Pre-configured Agent instance. If provided, other agent configuration
|
||||
parameters will be ignored.
|
||||
name: Name of the agent for identification and handoffs.
|
||||
instructions: System instructions that define the agent's behavior.
|
||||
handoffs: List of other agents this agent can hand off to.
|
||||
tools: List of callable functions the agent can use as tools.
|
||||
input_guardrails: List of input validation guardrails.
|
||||
output_guardrails: List of output validation guardrails.
|
||||
model_config: Configuration for the underlying language model.
|
||||
session_config: Configuration for session management.
|
||||
api_key: OpenAI API key. If not provided, will use OPENAI_API_KEY env var.
|
||||
streaming: Whether to use streaming responses for real-time output.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set up API key
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
elif not os.getenv("OPENAI_API_KEY"):
|
||||
logger.warning("No OpenAI API key provided. Set OPENAI_API_KEY environment variable.")
|
||||
|
||||
# Create or use existing agent
|
||||
if agent:
|
||||
self._agent = agent
|
||||
else:
|
||||
# Convert sequences to lists and handle string instructions
|
||||
agent_handoffs: List[Any] = list(handoffs) if handoffs else []
|
||||
agent_tools: List[Any] = list(tools) if tools else []
|
||||
agent_input_guardrails: List[Any] = list(input_guardrails) if input_guardrails else []
|
||||
agent_output_guardrails: List[Any] = (
|
||||
list(output_guardrails) if output_guardrails else []
|
||||
)
|
||||
|
||||
# Handle instructions - convert sequence to string if needed
|
||||
if isinstance(instructions, str):
|
||||
agent_instructions = instructions
|
||||
else:
|
||||
agent_instructions = " ".join(str(instr) for instr in instructions)
|
||||
|
||||
self._agent = Agent(
|
||||
name=name,
|
||||
instructions=agent_instructions,
|
||||
handoffs=agent_handoffs,
|
||||
tools=agent_tools,
|
||||
input_guardrails=agent_input_guardrails,
|
||||
output_guardrails=agent_output_guardrails,
|
||||
model=model_config.get("model", "gpt-4o") if model_config else "gpt-4o",
|
||||
)
|
||||
|
||||
self._streaming = streaming
|
||||
self._session_config = session_config or {}
|
||||
self._current_session = None
|
||||
self._accumulated_text = ""
|
||||
|
||||
# Set model name for metrics
|
||||
if model_config and "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
else:
|
||||
self.set_model_name("gpt-4o") # Default model
|
||||
|
||||
logger.info(f"Initialized OpenAI Agent service: {self._agent.name}")
|
||||
|
||||
@property
|
||||
def agent(self) -> Agent:
|
||||
"""Get the underlying OpenAI Agent.
|
||||
|
||||
Returns:
|
||||
The configured Agent instance.
|
||||
"""
|
||||
return self._agent
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIAgentContextAggregatorPair:
|
||||
"""Create OpenAI-specific context aggregators for agent interactions.
|
||||
|
||||
Creates a pair of context aggregators optimized for OpenAI Agent interactions,
|
||||
including support for function calls, tool usage, and conversation management.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
OpenAIAgentContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIAgentContextAggregatorPair.
|
||||
"""
|
||||
user = OpenAIAgentUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAgentAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIAgentContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def update_agent_config(
|
||||
self,
|
||||
*,
|
||||
instructions: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Update agent configuration dynamically.
|
||||
|
||||
Args:
|
||||
instructions: New system instructions for the agent.
|
||||
model_config: Updated model configuration.
|
||||
**kwargs: Additional agent configuration parameters.
|
||||
"""
|
||||
if instructions:
|
||||
self._agent.instructions = instructions
|
||||
logger.info(f"Updated agent instructions for {self._agent.name}")
|
||||
|
||||
if model_config:
|
||||
# Note: OpenAI Agents SDK handles model configuration during agent creation
|
||||
# We can't update model_config after agent is created, but we can update our model name
|
||||
if "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
logger.info(f"Updated model config for {self._agent.name}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the OpenAI Agent service.
|
||||
|
||||
Initializes the agent session and prepares for processing.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
logger.info(f"Starting OpenAI Agent service: {self._agent.name}")
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the OpenAI Agent service.
|
||||
|
||||
Cleans up resources and ends the current session.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
logger.info(f"Stopping OpenAI Agent service: {self._agent.name}")
|
||||
await super().stop(frame)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the OpenAI Agent service.
|
||||
|
||||
Cancels any ongoing operations.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
logger.info(f"Cancelling OpenAI Agent service: {self._agent.name}")
|
||||
await super().cancel(frame)
|
||||
|
||||
@override
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
"""Process frames and handle agent interactions.
|
||||
|
||||
Processes OpenAILLMContextFrame and TextFrame by running them through the OpenAI Agent
|
||||
and streams the results back as LLM frames.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
# Process context frame through the agent
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
# Extract the latest user message from the context
|
||||
messages = frame.context.get_messages()
|
||||
if messages:
|
||||
# Get the last user message
|
||||
for message in reversed(messages):
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
user_input = " ".join(text_parts)
|
||||
else:
|
||||
user_input = str(content)
|
||||
|
||||
if user_input.strip():
|
||||
await self._process_agent_request(user_input)
|
||||
break
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent context: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
elif isinstance(frame, TextFrame):
|
||||
# Process text input through the agent directly (for backwards compatibility)
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self._process_agent_request(frame.text)
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent request: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
else:
|
||||
# For frames we don't handle, pass them through with direction
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_agent_request(self, input_text: str):
|
||||
"""Process an agent request and stream the results.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
logger.debug(f"Processing agent request: {input_text}")
|
||||
|
||||
if self._streaming:
|
||||
await self._process_streaming_response(input_text)
|
||||
else:
|
||||
await self._process_non_streaming_response(input_text)
|
||||
|
||||
async def _process_streaming_response(self, input_text: str):
|
||||
"""Process a streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent with streaming
|
||||
result: RunResultStreaming = Runner.run_streamed(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
has_streaming_deltas = False
|
||||
|
||||
# Process the stream events
|
||||
async for event in result.stream_events():
|
||||
if event.type == "raw_response_event":
|
||||
# Handle token-by-token streaming
|
||||
# Only check for delta on events that are known to have it
|
||||
if hasattr(event.data, "delta") and getattr(event.data, "delta", None):
|
||||
delta_text = getattr(event.data, "delta", "")
|
||||
if delta_text:
|
||||
has_streaming_deltas = True
|
||||
self._accumulated_text += delta_text
|
||||
await self.push_frame(LLMTextFrame(text=delta_text))
|
||||
|
||||
elif event.type == "run_item_stream_event":
|
||||
# Handle completed items
|
||||
if event.item.type == "message_output_item":
|
||||
# Only process complete message if we didn't get streaming deltas
|
||||
if not has_streaming_deltas:
|
||||
message_text = self._extract_message_text(event.item)
|
||||
logger.debug(
|
||||
f"Processing complete message (no deltas): {message_text[:50]}..."
|
||||
if len(message_text) > 50
|
||||
else f"Processing complete message: {message_text}"
|
||||
)
|
||||
if message_text:
|
||||
await self.push_frame(LLMTextFrame(text=message_text))
|
||||
|
||||
elif event.item.type == "tool_call_item":
|
||||
# Use getattr for safe attribute access
|
||||
tool_name = getattr(event.item, "tool_name", "unknown")
|
||||
logger.debug(f"Tool called: {tool_name}")
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
output = getattr(event.item, "output", "no output")
|
||||
logger.debug(f"Tool output: {output}")
|
||||
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
logger.debug(f"Agent updated: {event.new_agent.name}")
|
||||
|
||||
# Reset accumulated text for next request
|
||||
self._accumulated_text = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming response: {e}")
|
||||
raise
|
||||
|
||||
async def _process_non_streaming_response(self, input_text: str):
|
||||
"""Process a non-streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent without streaming
|
||||
result: RunResult = await Runner.run(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
# Send the final output
|
||||
if result.final_output:
|
||||
await self.push_frame(LLMTextFrame(text=result.final_output))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in non-streaming response: {e}")
|
||||
raise
|
||||
|
||||
def _extract_message_text(self, item) -> str:
|
||||
"""Extract text from a message output item.
|
||||
|
||||
Args:
|
||||
item: The message output item from the agent.
|
||||
|
||||
Returns:
|
||||
The extracted text content.
|
||||
"""
|
||||
try:
|
||||
# Handle OpenAI Agents SDK MessageOutputItem format
|
||||
if hasattr(item, "raw_item") and hasattr(item.raw_item, "content"):
|
||||
content = item.raw_item.content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for content_part in content:
|
||||
if hasattr(content_part, "text"):
|
||||
text_parts.append(content_part.text)
|
||||
elif (
|
||||
isinstance(content_part, dict)
|
||||
and content_part.get("type") == "output_text"
|
||||
):
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
# Handle direct content attribute
|
||||
elif hasattr(item, "content"):
|
||||
if isinstance(item.content, str):
|
||||
return item.content
|
||||
elif isinstance(item.content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for content_part in item.content:
|
||||
if isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, str):
|
||||
text_parts.append(content_part)
|
||||
return "".join(text_parts)
|
||||
|
||||
# If no text content found, return empty string instead of str(item)
|
||||
logger.debug(f"No extractable text content found in item: {type(item)}")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from message item: {e}")
|
||||
return ""
|
||||
|
||||
async def add_tool(self, tool_function: ToolLike):
|
||||
"""Add a tool function to the agent.
|
||||
|
||||
Args:
|
||||
tool_function: A callable function or Tool object to add as a tool.
|
||||
"""
|
||||
if hasattr(self._agent, "tools"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
tools_list: List[Any] = self._agent.tools
|
||||
tools_list.append(tool_function)
|
||||
tool_name = getattr(
|
||||
tool_function, "__name__", getattr(tool_function, "name", "unknown")
|
||||
)
|
||||
logger.info(f"Added tool {tool_name} to agent {self._agent.name}")
|
||||
|
||||
async def add_handoff_agent(self, agent: AgentLike):
|
||||
"""Add a handoff agent.
|
||||
|
||||
Args:
|
||||
agent: Another Agent instance or handoff object that this agent can hand off to.
|
||||
"""
|
||||
if hasattr(self._agent, "handoffs"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
handoffs_list: List[Any] = self._agent.handoffs
|
||||
handoffs_list.append(agent)
|
||||
agent_name = getattr(agent, "name", "unknown")
|
||||
logger.info(f"Added handoff agent {agent_name} to agent {self._agent.name}")
|
||||
|
||||
def get_session_context(self) -> Dict[str, Any]:
|
||||
"""Get the current session context.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current session context.
|
||||
"""
|
||||
return self._session_config.copy()
|
||||
|
||||
def update_session_context(self, context: Dict[str, Any]):
|
||||
"""Update the session context.
|
||||
|
||||
Args:
|
||||
context: Dictionary of context updates to apply.
|
||||
"""
|
||||
self._session_config.update(context)
|
||||
logger.debug(f"Updated session context for agent {self._agent.name}")
|
||||
|
||||
|
||||
class OpenAIAgentUserContextAggregator(LLMUserContextAggregator):
|
||||
"""OpenAI Agent-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for OpenAI Agent services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAgentAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""OpenAI Agent-specific assistant context aggregator.
|
||||
|
||||
Handles aggregation of assistant messages for OpenAI Agent services,
|
||||
with specialized support for OpenAI's function calling format,
|
||||
tool usage tracking, and agent interaction management.
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -25,7 +25,6 @@ from pydantic import BaseModel
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
ControlFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -42,7 +41,6 @@ from pipecat.frames.frames import (
|
||||
UserAudioRawFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
DataFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -107,17 +105,6 @@ class DailyInputTransportMessageUrgentFrame(InputTransportMessageUrgentFrame):
|
||||
participant_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
|
||||
"""Frame to update remote participants in Daily calls.
|
||||
|
||||
Parameters:
|
||||
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
|
||||
"""
|
||||
|
||||
remote_participants: Mapping[str, Any] = None
|
||||
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
"""Voice Activity Detection analyzer using WebRTC.
|
||||
|
||||
@@ -228,7 +215,6 @@ class DailyCallbacks(BaseModel):
|
||||
on_active_speaker_changed: Called when the active speaker of the call has changed.
|
||||
on_joined: Called when bot successfully joined a room.
|
||||
on_left: Called when bot left a room.
|
||||
on_before_leave: Called when bot is about to leave the room.
|
||||
on_error: Called when an error occurs.
|
||||
on_app_message: Called when receiving an app message.
|
||||
on_call_state_updated: Called when call state changes.
|
||||
@@ -258,7 +244,6 @@ class DailyCallbacks(BaseModel):
|
||||
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_left: Callable[[], Awaitable[None]]
|
||||
on_before_leave: Callable[[], Awaitable[None]]
|
||||
on_error: Callable[[str], Awaitable[None]]
|
||||
on_app_message: Callable[[Any, str], Awaitable[None]]
|
||||
on_call_state_updated: Callable[[str], Awaitable[None]]
|
||||
@@ -374,7 +359,6 @@ class DailyTransportClient(EventHandler):
|
||||
self._transcription_ids = []
|
||||
self._transcription_status = None
|
||||
self._dial_out_session_id: str = ""
|
||||
self._dial_in_session_id: str = ""
|
||||
|
||||
self._joining = False
|
||||
self._joined = False
|
||||
@@ -735,9 +719,6 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
logger.info(f"Leaving {self._room_url}")
|
||||
|
||||
# Call callback before leaving.
|
||||
await self._callbacks.on_before_leave()
|
||||
|
||||
if self._params.transcription_enabled:
|
||||
await self.stop_transcription()
|
||||
|
||||
@@ -842,16 +823,6 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
settings: SIP call transfer settings.
|
||||
"""
|
||||
session_id = (
|
||||
settings.get("sessionId") or self._dial_out_session_id or self._dial_in_session_id
|
||||
)
|
||||
if not session_id:
|
||||
logger.error("Unable to transfer SIP call: 'sessionId' is not set")
|
||||
return
|
||||
|
||||
# Update 'sessionId' field.
|
||||
settings["sessionId"] = session_id
|
||||
|
||||
future = self._get_event_loop().create_future()
|
||||
self._client.sip_call_transfer(settings, completion=completion_callback(future))
|
||||
await future
|
||||
@@ -1170,7 +1141,6 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in connection data.
|
||||
"""
|
||||
self._dial_in_session_id = data["sessionId"] if "sessionId" in data else ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_connected, data)
|
||||
|
||||
def on_dialin_ready(self, sip_endpoint: str):
|
||||
@@ -1187,9 +1157,6 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in stop data.
|
||||
"""
|
||||
# Cleanup only if our session stopped.
|
||||
if data.get("sessionId") == self._dial_in_session_id:
|
||||
self._dial_in_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_stopped, data)
|
||||
|
||||
def on_dialin_error(self, data: Any):
|
||||
@@ -1198,9 +1165,6 @@ class DailyTransportClient(EventHandler):
|
||||
Args:
|
||||
data: Dial-in error data.
|
||||
"""
|
||||
# Cleanup only if our session errored out.
|
||||
if data.get("sessionId") == self._dial_in_session_id:
|
||||
self._dial_in_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialin_error, data)
|
||||
|
||||
def on_dialin_warning(self, data: Any):
|
||||
@@ -1235,7 +1199,7 @@ class DailyTransportClient(EventHandler):
|
||||
data: Dial-out stop data.
|
||||
"""
|
||||
# Cleanup only if our session stopped.
|
||||
if data.get("sessionId") == self._dial_out_session_id:
|
||||
if data["sessionId"] == self._dial_out_session_id:
|
||||
self._dial_out_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialout_stopped, data)
|
||||
|
||||
@@ -1246,7 +1210,7 @@ class DailyTransportClient(EventHandler):
|
||||
data: Dial-out error data.
|
||||
"""
|
||||
# Cleanup only if our session errored out.
|
||||
if data.get("sessionId") == self._dial_out_session_id:
|
||||
if data["sessionId"] == self._dial_out_session_id:
|
||||
self._dial_out_session_id = ""
|
||||
self._call_event_callback(self._callbacks.on_dialout_error, data)
|
||||
|
||||
@@ -1803,31 +1767,6 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process outgoing frames, including transport messages.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
logger.debug(f"Got a DailyUpdateRemoteParticipantsFrame: {frame}")
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process outgoing frames, including transport messages.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
"""Send a transport message to participants.
|
||||
|
||||
@@ -1923,7 +1862,6 @@ class DailyTransport(BaseTransport):
|
||||
on_active_speaker_changed=self._on_active_speaker_changed,
|
||||
on_joined=self._on_joined,
|
||||
on_left=self._on_left,
|
||||
on_before_leave=self._on_before_leave,
|
||||
on_error=self._on_error,
|
||||
on_app_message=self._on_app_message,
|
||||
on_call_state_updated=self._on_call_state_updated,
|
||||
@@ -1987,10 +1925,6 @@ class DailyTransport(BaseTransport):
|
||||
self._register_event_handler("on_recording_started")
|
||||
self._register_event_handler("on_recording_stopped")
|
||||
self._register_event_handler("on_recording_error")
|
||||
self._register_event_handler("on_before_disconnect", sync=True)
|
||||
# Deprecated
|
||||
self._register_event_handler("on_joined")
|
||||
self._register_event_handler("on_left")
|
||||
|
||||
#
|
||||
# BaseTransport
|
||||
@@ -2242,10 +2176,6 @@ class DailyTransport(BaseTransport):
|
||||
"""Handle room left events."""
|
||||
await self._call_event_handler("on_left")
|
||||
|
||||
async def _on_before_leave(self):
|
||||
"""Handle before leave room events."""
|
||||
await self._call_event_handler("on_before_disconnect")
|
||||
|
||||
async def _on_error(self, error):
|
||||
"""Handle error events and push error frames."""
|
||||
await self._call_event_handler("on_error", error)
|
||||
@@ -2385,7 +2315,7 @@ class DailyTransport(BaseTransport):
|
||||
"""Handle participant updated events."""
|
||||
await self._call_event_handler("on_participant_updated", participant)
|
||||
|
||||
async def _on_transcription_message(self, message: Mapping[str, Any]) -> None:
|
||||
async def _on_transcription_message(self, message: Dict[str, Any]) -> None:
|
||||
"""Handle transcription message events."""
|
||||
await self._call_event_handler("on_transcription_message", message)
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ class LiveKitCallbacks(BaseModel):
|
||||
|
||||
on_connected: Callable[[], Awaitable[None]]
|
||||
on_disconnected: Callable[[], Awaitable[None]]
|
||||
on_before_disconnect: Callable[[], Awaitable[None]]
|
||||
on_participant_connected: Callable[[str], Awaitable[None]]
|
||||
on_participant_disconnected: Callable[[str], Awaitable[None]]
|
||||
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
|
||||
@@ -283,7 +282,6 @@ class LiveKitTransportClient:
|
||||
return
|
||||
|
||||
logger.info(f"Disconnecting from {self._room_name}")
|
||||
await self._callbacks.on_before_disconnect()
|
||||
await self.room.disconnect()
|
||||
self._connected = False
|
||||
logger.info(f"Disconnected from {self._room_name}")
|
||||
@@ -920,7 +918,6 @@ class LiveKitTransport(BaseTransport):
|
||||
callbacks = LiveKitCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_disconnected=self._on_disconnected,
|
||||
on_before_disconnect=self._on_before_disconnect,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
on_audio_track_subscribed=self._on_audio_track_subscribed,
|
||||
@@ -950,7 +947,6 @@ class LiveKitTransport(BaseTransport):
|
||||
self._register_event_handler("on_first_participant_joined")
|
||||
self._register_event_handler("on_participant_left")
|
||||
self._register_event_handler("on_call_state_updated")
|
||||
self._register_event_handler("on_before_disconnect", sync=True)
|
||||
|
||||
def input(self) -> LiveKitInputTransport:
|
||||
"""Get the input transport for receiving media and events.
|
||||
@@ -1045,10 +1041,6 @@ class LiveKitTransport(BaseTransport):
|
||||
"""Handle room disconnected events."""
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _on_before_disconnect(self):
|
||||
"""Handle before disconnection room events."""
|
||||
await self._call_event_handler("on_before_disconnect")
|
||||
|
||||
async def _on_participant_connected(self, participant_id: str):
|
||||
"""Handle participant connected events."""
|
||||
await self._call_event_handler("on_participant_connected", participant_id)
|
||||
|
||||
@@ -95,20 +95,15 @@ class SmallWebRTCTrack:
|
||||
enable/disable control and frame discarding for audio and video streams.
|
||||
"""
|
||||
|
||||
def __init__(self, receiver):
|
||||
def __init__(self, track: MediaStreamTrack):
|
||||
"""Initialize the WebRTC track wrapper.
|
||||
|
||||
Args:
|
||||
receiver: The RemoteStreamTrack receiver instance.
|
||||
track: The underlying MediaStreamTrack to wrap.
|
||||
index: The index of the track in the transceiver (0 for mic, 1 for cam, 2 for screen)
|
||||
"""
|
||||
self._receiver = receiver
|
||||
# Configure the receiver not to consume the track by default, to prevent memory growth
|
||||
self._receiver._enabled = False
|
||||
self._track = receiver.track
|
||||
self._track = track
|
||||
self._enabled = True
|
||||
self._last_recv_time: float = 0.0
|
||||
self._idle_task: Optional[asyncio.Task] = None
|
||||
self._idle_timeout: float = 2.0 # seconds before discarding old frames
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
"""Enable or disable the track.
|
||||
@@ -143,44 +138,13 @@ class SmallWebRTCTrack:
|
||||
async def recv(self) -> Optional[Frame]:
|
||||
"""Receive the next frame from the track.
|
||||
|
||||
Enables the internal receiving state and starts idle watcher.
|
||||
|
||||
Returns:
|
||||
The next frame, except for video tracks, where it returns the frame only if the track is enabled, otherwise, returns None.
|
||||
"""
|
||||
self._receiver._enabled = True
|
||||
self._last_recv_time = time.time()
|
||||
|
||||
# start idle watcher if not already running
|
||||
if not self._idle_task or self._idle_task.done():
|
||||
self._idle_task = asyncio.create_task(self._idle_watcher())
|
||||
|
||||
if not self._enabled and self._track.kind == "video":
|
||||
return None
|
||||
return await self._track.recv()
|
||||
|
||||
async def _idle_watcher(self):
|
||||
"""Disable receiving if idle for more than _idle_timeout and monitor queue size."""
|
||||
while self._receiver._enabled:
|
||||
await asyncio.sleep(self._idle_timeout)
|
||||
idle_duration = time.time() - self._last_recv_time
|
||||
if idle_duration >= self._idle_timeout:
|
||||
# discard old frames to prevent memory growth
|
||||
logger.debug(
|
||||
f"Disabling receiver for {self._track.kind} track after {idle_duration:.2f}s idle"
|
||||
)
|
||||
await self.discard_old_frames()
|
||||
self._receiver._enabled = False
|
||||
|
||||
def stop(self):
|
||||
"""Stop receiving frames from the track."""
|
||||
self._receiver._enabled = False
|
||||
if self._idle_task:
|
||||
self._idle_task.cancel()
|
||||
self._idle_task = None
|
||||
if self._track:
|
||||
self._track.stop()
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forward attribute access to the underlying track.
|
||||
|
||||
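The `recv()` / `_idle_watcher()` pair in the hunk above enables a receiver on use and switches it off again after a quiet period. A minimal sketch of that pattern in isolation is shown below. `IdleGate` is a hypothetical class and not Pipecat's `SmallWebRTCTrack`, and it uses `time.monotonic()` where the diffed code uses `time.time()`:

```python
import asyncio
import time
from typing import Optional, Tuple


class IdleGate:
    """Hypothetical helper: stays enabled while touched, disables itself when idle."""

    def __init__(self, idle_timeout: float = 2.0) -> None:
        self.enabled = False
        self._idle_timeout = idle_timeout
        self._last_touch = 0.0
        self._watcher: Optional[asyncio.Task] = None

    def touch(self) -> None:
        """Mark activity and make sure a watcher task is running."""
        self.enabled = True
        self._last_touch = time.monotonic()
        if self._watcher is None or self._watcher.done():
            self._watcher = asyncio.create_task(self._watch())

    async def _watch(self) -> None:
        """Disable the gate once no touch has happened for idle_timeout seconds."""
        while self.enabled:
            await asyncio.sleep(self._idle_timeout)
            if time.monotonic() - self._last_touch >= self._idle_timeout:
                self.enabled = False

    def stop(self) -> None:
        """Disable the gate and cancel the watcher, if any."""
        self.enabled = False
        if self._watcher:
            self._watcher.cancel()
            self._watcher = None


async def demo() -> Tuple[bool, bool]:
    gate = IdleGate(idle_timeout=0.05)
    gate.touch()
    enabled_after_touch = gate.enabled
    await asyncio.sleep(0.2)  # several timeouts pass with no further touch
    return enabled_after_touch, gate.enabled
```

The watcher exits on its own once the gate is disabled, so a later `touch()` simply starts a fresh one.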
@@ -490,10 +454,6 @@ class SmallWebRTCConnection(BaseObject):
|
||||
|
||||
async def _close(self):
|
||||
"""Close the peer connection and cleanup resources."""
|
||||
for track in self._track_map.values():
|
||||
if track:
|
||||
track.stop()
|
||||
self._track_map.clear()
|
||||
if self._pc:
|
||||
await self._pc.close()
|
||||
self._message_queue.clear()
|
||||
@@ -566,8 +526,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No audio transceiver is available")
|
||||
return None
|
||||
|
||||
receiver = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver
|
||||
audio_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
|
||||
audio_track = SmallWebRTCTrack(track) if track else None
|
||||
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
|
||||
return audio_track
|
||||
|
||||
@@ -588,8 +548,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No video transceiver is available")
|
||||
return None
|
||||
|
||||
receiver = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver
|
||||
video_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
|
||||
video_track = SmallWebRTCTrack(track) if track else None
|
||||
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
|
||||
return video_track
|
||||
|
||||
@@ -610,8 +570,8 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.warning("No screen video transceiver is available")
|
||||
return None
|
||||
|
||||
receiver = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver
|
||||
video_track = SmallWebRTCTrack(receiver) if receiver else None
|
||||
track = transceivers[SCREEN_VIDEO_TRANSCEIVER_INDEX].receiver.track
|
||||
video_track = SmallWebRTCTrack(track) if track else None
|
||||
self._track_map[SCREEN_VIDEO_TRANSCEIVER_INDEX] = video_track
|
||||
return video_track
|
||||
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""SmallWebRTC request handler for managing peer connections.
|
||||
|
||||
This module provides a client for handling web requests and managing WebRTC connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmallWebRTCRequest:
|
||||
"""Small WebRTC transport session arguments for the runner.
|
||||
|
||||
Parameters:
|
||||
sdp: The SDP string (Session Description Protocol).
|
||||
type: The type of the SDP, either "offer" or "answer".
|
||||
pc_id: Optional identifier for the peer connection.
|
||||
restart_pc: Optional flag indicating whether to restart the peer connection.
|
||||
request_data: Optional custom data sent by the customer.
|
||||
"""
|
||||
|
||||
sdp: str
|
||||
type: str
|
||||
pc_id: Optional[str] = None
|
||||
restart_pc: Optional[bool] = None
|
||||
request_data: Optional[Any] = None
|
||||
|
||||
|
||||
class ConnectionMode(Enum):
|
||||
"""Enum defining the connection handling modes."""
|
||||
|
||||
SINGLE = "single" # Only one active connection allowed
|
||||
MULTIPLE = "multiple" # Multiple simultaneous connections allowed
|
||||
|
||||
|
||||
class SmallWebRTCRequestHandler:
|
||||
"""SmallWebRTC request handler for managing peer connections.
|
||||
|
||||
This class is responsible for:
|
||||
- Handling incoming SmallWebRTC requests.
|
||||
- Creating and managing WebRTC peer connections.
|
||||
- Supporting ESP32-specific SDP munging if enabled.
|
||||
- Invoking callbacks for newly initialized connections.
|
||||
- Supporting both single and multiple connection modes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ice_servers: Optional[List[IceServer]] = None,
|
||||
esp32_mode: bool = False,
|
||||
host: Optional[str] = None,
|
||||
connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
|
||||
) -> None:
|
||||
"""Initialize a SmallWebRTC request handler.
|
||||
|
||||
Args:
|
||||
ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
|
||||
connections.
|
||||
esp32_mode (bool): If True, enables ESP32-specific SDP munging.
|
||||
host (Optional[str]): Host address used for SDP munging in ESP32 mode.
|
||||
Ignored if `esp32_mode` is False.
|
||||
connection_mode (ConnectionMode): Mode of operation for handling connections.
|
||||
SINGLE allows only one active connection, MULTIPLE allows several.
|
||||
"""
|
||||
self._ice_servers = ice_servers
|
||||
self._esp32_mode = esp32_mode
|
||||
self._host = host
|
||||
self._connection_mode = connection_mode
|
||||
|
||||
# Store connections by pc_id
|
||||
self._pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
def _check_single_connection_constraints(self, pc_id: Optional[str]) -> None:
|
||||
"""Check if the connection request satisfies single connection mode constraints.
|
||||
|
||||
Args:
|
||||
pc_id: The peer connection ID from the request
|
||||
|
||||
Raises:
|
||||
HTTPException: If constraints are violated in single connection mode
|
||||
"""
|
||||
if self._connection_mode != ConnectionMode.SINGLE:
|
||||
return
|
||||
|
||||
if not self._pcs_map: # No existing connections
|
||||
return
|
||||
|
||||
# Get the existing connection (should be only one in single mode)
|
||||
existing_connection = next(iter(self._pcs_map.values()))
|
||||
|
||||
if existing_connection.pc_id != pc_id and pc_id:
|
||||
logger.warning(
|
||||
f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")
|
||||
|
||||
if not pc_id:
|
||||
logger.warning(
|
||||
"Cannot create new connection: existing connection found but no pc_id received"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot create new connection with existing connection active",
|
||||
)
|
||||
|
||||
async def handle_web_request(
|
||||
self,
|
||||
request: SmallWebRTCRequest,
|
||||
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Handle a SmallWebRTC request and resolve the pending answer.
|
||||
|
||||
This method will:
|
||||
- Reuse an existing WebRTC connection if `pc_id` exists.
|
||||
- Otherwise, create a new `SmallWebRTCConnection`.
|
||||
- Invoke the provided callback with the connection.
|
||||
- Manage ESP32-specific munging if enabled.
|
||||
- Enforce single/multiple connection mode constraints.
|
||||
|
||||
Args:
|
||||
request (SmallWebRTCRequest): The incoming WebRTC request, containing
|
||||
SDP, type, and optionally a `pc_id`.
|
||||
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
|
||||
asynchronous callback function that is invoked with the WebRTC connection.
|
||||
|
||||
Raises:
|
||||
HTTPException: If connection mode constraints are violated
|
||||
Exception: Any exception raised during request handling or callback execution
|
||||
will be logged and propagated.
|
||||
"""
|
||||
try:
|
||||
pc_id = request.pc_id
|
||||
|
||||
# Check connection mode constraints first
|
||||
self._check_single_connection_constraints(pc_id)
|
||||
|
||||
# After constraints are satisfied, get the existing connection if any
|
||||
existing_connection = self._pcs_map.get(pc_id) if pc_id else None
|
||||
|
||||
if existing_connection:
|
||||
pipecat_connection = existing_connection
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(
|
||||
sdp=request.sdp,
|
||||
type=request.type,
|
||||
restart_pc=request.restart_pc or False,
|
||||
)
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
|
||||
await pipecat_connection.initialize(sdp=request.sdp, type=request.type)
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
|
||||
self._pcs_map.pop(webrtc_connection.pc_id, None)
|
||||
|
||||
# Invoke callback provided in runner arguments
|
||||
try:
|
||||
await webrtc_connection_callback(pipecat_connection)
|
||||
logger.debug(
|
||||
f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
|
||||
)
|
||||
except Exception as callback_error:
|
||||
logger.error(
|
||||
f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
|
||||
)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
|
||||
if self._esp32_mode and self._host and self._host != "localhost":
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)
|
||||
|
||||
self._pcs_map[answer["pc_id"]] = pipecat_connection
|
||||
|
||||
return answer
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SmallWebRTC request: {e}")
|
||||
logger.debug(f"SmallWebRTC request details: {request}")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Disconnect all peer connections and clear the connection map."""
|
||||
coros = [pc.disconnect() for pc in self._pcs_map.values()]
|
||||
await asyncio.gather(*coros)
|
||||
self._pcs_map.clear()
|
||||
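The single-connection rule enforced by `_check_single_connection_constraints` in the file above can be read as a small decision table. A framework-free sketch of it follows. `ConnectionRejected` and `check_single_connection` are hypothetical names introduced here; the diffed handler raises `fastapi.HTTPException` with status 400 instead:

```python
from enum import Enum
from typing import List, Optional


class ConnectionMode(Enum):
    SINGLE = "single"      # only one active connection allowed
    MULTIPLE = "multiple"  # multiple simultaneous connections allowed


class ConnectionRejected(Exception):
    """Hypothetical stand-in for the HTTP 400 error raised by the handler."""


def check_single_connection(
    mode: ConnectionMode, active_ids: List[str], pc_id: Optional[str]
) -> None:
    """Raise ConnectionRejected if the request violates single-connection mode."""
    if mode is not ConnectionMode.SINGLE or not active_ids:
        return  # nothing to enforce
    existing = active_ids[0]  # at most one entry is expected in single mode
    if not pc_id:
        # A request without pc_id would create a second connection.
        raise ConnectionRejected("Cannot create new connection with existing connection active")
    if pc_id != existing:
        raise ConnectionRejected("PC ID mismatch with existing connection")
```

A request carrying the `pc_id` of the existing connection passes the check, which is what allows renegotiation in single mode.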
@@ -14,33 +14,13 @@ and async cleanup for all Pipecat components.
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandler:
|
||||
"""Data class to store event handlers information.
|
||||
|
||||
This data class stores the event name, a list of handlers to run for this
|
||||
event, and whether these handlers will be executed in a task.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the event handler.
|
||||
handlers (List[Any]): A list of functions to be called when this event is triggered.
|
||||
is_sync (bool): If True, handlers are awaited inline instead of each running in its own task.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
handlers: List[Any]
|
||||
is_sync: bool
|
||||
|
||||
|
||||
class BaseObject(ABC):
|
||||
"""Abstract base class providing common functionality for Pipecat objects.
|
||||
|
||||
@@ -61,7 +41,7 @@ class BaseObject(ABC):
|
||||
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
|
||||
# Registered event handlers.
|
||||
self._event_handlers: Dict[str, EventHandler] = {}
|
||||
self._event_handlers: dict = {}
|
||||
|
||||
# Set of tasks being executed. When a task finishes running it gets
|
||||
# automatically removed from the set. When we cleanup we wait for all
|
||||
@@ -123,21 +103,18 @@ class BaseObject(ABC):
|
||||
Can be sync or async.
|
||||
"""
|
||||
if event_name in self._event_handlers:
|
||||
self._event_handlers[event_name].handlers.append(handler)
|
||||
self._event_handlers[event_name].append(handler)
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} not registered")
|
||||
|
||||
def _register_event_handler(self, event_name: str, sync: bool = False):
|
||||
def _register_event_handler(self, event_name: str):
|
||||
"""Register an event handler type.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event type to register.
|
||||
sync: If True, handlers for this event are awaited inline rather than run in a separate task.
|
||||
"""
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = EventHandler(
|
||||
name=event_name, handlers=[], is_sync=sync
|
||||
)
|
||||
self._event_handlers[event_name] = []
|
||||
else:
|
||||
logger.warning(f"Event handler {event_name} already registered")
|
||||
|
||||
@@ -149,43 +126,34 @@ class BaseObject(ABC):
|
||||
*args: Positional arguments to pass to event handlers.
|
||||
**kwargs: Keyword arguments to pass to event handlers.
|
||||
"""
|
||||
if event_name not in self._event_handlers:
|
||||
# If we haven't registered an event handler, we don't need to do
|
||||
# anything.
|
||||
if not self._event_handlers.get(event_name):
|
||||
return
|
||||
|
||||
event_handler = self._event_handlers[event_name]
|
||||
# Create the task.
|
||||
task = asyncio.create_task(self._run_task(event_name, *args, **kwargs))
|
||||
|
||||
for handler in event_handler.handlers:
|
||||
if event_handler.is_sync:
|
||||
# Just run the handler.
|
||||
await self._run_handler(event_handler.name, handler, *args, **kwargs)
|
||||
else:
|
||||
# Create the task. Note that this is a task per each function
|
||||
# handler. Users can register to an event handler multiple
|
||||
# times.
|
||||
task = asyncio.create_task(
|
||||
self._run_handler(event_handler.name, handler, *args, **kwargs)
|
||||
)
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
|
||||
# Add it to our list of event tasks.
|
||||
self._event_tasks.add((event_name, task))
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
|
||||
# Remove the task from the event tasks list when the task completes.
|
||||
task.add_done_callback(self._event_task_finished)
|
||||
|
||||
async def _run_handler(self, event_name: str, handler, *args, **kwargs):
|
||||
async def _run_task(self, event_name: str, *args, **kwargs):
|
||||
"""Execute all handlers for an event.
|
||||
|
||||
Args:
|
||||
event_name: The event name for this handler.
|
||||
handler: The handler function to run.
|
||||
event_name: The name of the event being handled.
|
||||
*args: Positional arguments to pass to handlers.
|
||||
**kwargs: Keyword arguments to pass to handlers.
|
||||
"""
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(self, *args, **kwargs)
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
for handler in self._event_handlers[event_name]:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(self, *args, **kwargs)
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||
|
||||
|
||||
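The two sides of the `BaseObject` hunks above dispatch handlers differently: one keeps an `EventHandler` record with an `is_sync` flag and either awaits each handler inline or spawns one task per handler, the other spawns a single task that runs all handlers in order. A minimal sketch of the flag-based variant is given below, to show the ordering difference. `HandlerEntry` and `MiniEmitter` are hypothetical names and not Pipecat classes, and handlers here do not receive the emitter as first argument:

```python
import asyncio
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Set


@dataclass
class HandlerEntry:
    """Hypothetical record: handlers for one event plus a dispatch flag."""

    name: str
    handlers: List[Callable[..., Any]] = field(default_factory=list)
    is_sync: bool = False


class MiniEmitter:
    """Illustrative stand-in for an object with registered event handlers."""

    def __init__(self) -> None:
        self._entries: Dict[str, HandlerEntry] = {}
        self._tasks: Set[asyncio.Task] = set()

    def register(self, name: str, sync: bool = False) -> None:
        self._entries.setdefault(name, HandlerEntry(name=name, is_sync=sync))

    def add_handler(self, name: str, handler: Callable[..., Any]) -> None:
        self._entries[name].handlers.append(handler)

    async def emit(self, name: str, *args: Any) -> None:
        entry = self._entries.get(name)
        if entry is None:
            return
        for handler in entry.handlers:
            if entry.is_sync:
                # Awaited inline: emit() returns only after the handler finishes.
                await self._run(handler, *args)
            else:
                # One background task per handler, tracked until it completes.
                task = asyncio.create_task(self._run(handler, *args))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)

    async def _run(self, handler: Callable[..., Any], *args: Any) -> None:
        if inspect.iscoroutinefunction(handler):
            await handler(*args)
        else:
            handler(*args)

    async def wait_for_tasks(self) -> None:
        if self._tasks:
            await asyncio.gather(*self._tasks)


async def demo() -> List[str]:
    log: List[str] = []
    emitter = MiniEmitter()
    emitter.register("on_before_disconnect", sync=True)
    emitter.register("on_left")

    async def before_disconnect() -> None:
        await asyncio.sleep(0)
        log.append("before_disconnect handled")

    async def left() -> None:
        log.append("left handled")

    emitter.add_handler("on_before_disconnect", before_disconnect)
    emitter.add_handler("on_left", left)

    await emitter.emit("on_before_disconnect")
    log.append("disconnecting")
    await emitter.emit("on_left")
    log.append("emit(on_left) returned")
    await emitter.wait_for_tasks()
    return log
```

In `demo()`, the sync handler is guaranteed to finish before "disconnecting" is logged, while the task-based handler only runs after `emit()` has already returned. That ordering guarantee is the reason an `on_before_disconnect`-style event would be registered with `sync=True`.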
172
test_openai_agent.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Simple test script for OpenAI Agent service."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Mock the OpenAI API key for testing
|
||||
os.environ["OPENAI_API_KEY"] = "test-key-for-testing"
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai_agent import OpenAIAgentService
|
||||
|
||||
|
||||
async def test_basic_functionality():
|
||||
"""Test basic OpenAI Agent service functionality."""
|
||||
print("🧪 Testing OpenAI Agent Service...")
|
||||
|
||||
# Create a simple weather tool for testing
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"The weather in {location} is sunny and 22°C."
|
||||
|
||||
try:
|
||||
# Create the service
|
||||
print("📋 Creating OpenAI Agent service...")
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
tools=[get_weather],
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
print(f"✅ Service created successfully!")
|
||||
print(f" - Agent name: {service.agent.name}")
|
||||
print(f" - Model name: {service.model_name}")
|
||||
print(f" - Streaming enabled: {service._streaming}")
|
||||
|
||||
# Test basic configuration
|
||||
print("⚙️ Testing configuration updates...")
|
||||
service.update_agent_config(
|
||||
instructions="Updated test instructions",
|
||||
model_config={"model": "gpt-4o", "temperature": 0.5},
|
||||
)
|
||||
|
||||
print(f"✅ Configuration updated!")
|
||||
print(f" - New instructions: {service.agent.instructions}")
|
||||
print(f" - New model: {service.model_name}")
|
||||
|
||||
# Test session context
|
||||
print("💾 Testing session context...")
|
||||
service.update_session_context({"user_id": "test-user", "session": "test-session"})
|
||||
context = service.get_session_context()
|
||||
|
||||
print(f"✅ Session context managed!")
|
||||
print(f" - Context keys: {list(context.keys())}")
|
||||
|
||||
# Test adding tools
|
||||
print("🔧 Testing tool management...")
|
||||
|
||||
def get_time() -> str:
|
||||
"""Get current time."""
|
||||
return "The current time is 3:00 PM."
|
||||
|
||||
await service.add_tool(get_time)
|
||||
print(f"✅ Tool added successfully!")
|
||||
|
||||
print("\n🎉 All basic functionality tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_frame_processing():
|
||||
"""Test frame processing with mocked responses."""
|
||||
print("\n🔄 Testing frame processing...")
|
||||
|
||||
try:
|
||||
# Mock the Runner to avoid actual API calls
|
||||
with patch("pipecat.services.openai_agent.agent_service.Runner") as mock_runner:
|
||||
# Set up mock responses
|
||||
mock_stream_result = MagicMock()
|
||||
|
||||
# Mock stream events
|
||||
async def mock_stream_events():
|
||||
# Simulate streaming response
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="Hello "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="from "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="agent!"))
|
||||
|
||||
# Simulate completed message
|
||||
mock_item = MagicMock()
|
||||
mock_item.type = "message_output_item"
|
||||
mock_item.content = "Hello from agent!"
|
||||
yield MagicMock(type="run_item_stream_event", item=mock_item)
|
||||
|
||||
mock_stream_result.stream_events.return_value = mock_stream_events()
|
||||
mock_runner.run_streamed.return_value = mock_stream_result
|
||||
|
||||
# Create service with mocked runner
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Collect output frames
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
print(f" 📤 Frame: {type(frame).__name__}")
|
||||
if hasattr(frame, "text"):
|
||||
print(f" Text: '{frame.text}'")
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
print("📝 Processing text frame...")
|
||||
text_frame = TextFrame("Hello, how are you?")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait for async processing
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
print(f"✅ Frame processing completed!")
|
||||
print(f" - Generated {len(output_frames)} output frames")
|
||||
|
||||
# Check if we got expected frame types
|
||||
frame_types = [type(frame).__name__ for frame in output_frames]
|
||||
print(f" - Frame types: {frame_types}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Frame processing test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting OpenAI Agent Service Tests\n")
|
||||
|
||||
try:
|
||||
# Run basic functionality tests
|
||||
basic_test = await test_basic_functionality()
|
||||
|
||||
# Run frame processing tests
|
||||
frame_test = await test_frame_processing()
|
||||
|
||||
# Summary
|
||||
print(f"\n📊 Test Results:")
|
||||
print(f" - Basic functionality: {'✅ PASS' if basic_test else '❌ FAIL'}")
|
||||
print(f" - Frame processing: {'✅ PASS' if frame_test else '❌ FAIL'}")
|
||||
|
||||
if basic_test and frame_test:
|
||||
print(f"\n🎉 All tests passed! The OpenAI Agent service is working correctly.")
|
||||
else:
|
||||
print(f"\n⚠️ Some tests failed. Please check the output above.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test suite failed with error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
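The script above mocks a streaming result with an async generator but only prints what it receives. A minimal sketch of the same mocking technique with an actual assertion follows. The event shape (`type`, `data.delta`) is copied from the script; `fake_events`, `collect_text`, and `demo` are hypothetical helpers, and `side_effect` is used here in place of the script's `return_value` so that each call gets a fresh generator:

```python
import asyncio
from unittest.mock import MagicMock


async def fake_events():
    """Yield mock streaming events shaped like the ones in the script above."""
    for delta in ("Hello ", "from ", "agent!"):
        yield MagicMock(type="raw_response_event", data=MagicMock(delta=delta))


async def collect_text(result) -> str:
    """Concatenate the text deltas from a (mocked) streaming result."""
    parts = []
    async for event in result.stream_events():
        if event.type == "raw_response_event":
            parts.append(event.data.delta)
    return "".join(parts)


def demo() -> str:
    result = MagicMock()
    # side_effect calls fake_events() on every stream_events() call,
    # so the generator is not exhausted after the first consumer.
    result.stream_events.side_effect = fake_events
    text = asyncio.run(collect_text(result))
    assert text == "Hello from agent!"
    return text
```

Asserting on the collected text makes a silent regression fail the test, where a print-only script would still report success.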
33
test_simple_agent.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# Test the actual agents package API
|
||||
try:
|
||||
from agents import Agent, run
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
name="test-agent",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
print("✅ Agent created successfully!")
|
||||
print(f"Agent name: {agent.name}")
|
||||
|
||||
# Test a simple conversation
|
||||
async def test_agent():
|
||||
result = await run(agent, "Hello, how are you?")
|
||||
print(f"Agent response: {result}")
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_agent())
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -1,67 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
TextFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_interruption_and_wait(self):
|
||||
class DelayFrameProcessor(FrameProcessor):
|
||||
"""This processor just gives time to the event loop to change
|
||||
between tasks. Otherwise things happen too fast."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await asyncio.sleep(0.1)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class InterruptFrameProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(TransportMessageUrgentFrame(message=frame.text))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
# Just a random interruption to make sure we don't clear anything
|
||||
# before the actual `InterruptionTaskFrame` interruption.
|
||||
InterruptionFrame(),
|
||||
# This will generate an `InterruptionTaskFrame` and will wait for an
|
||||
# `InterruptionFrame`.
|
||||
TextFrame(text="Hello from Pipecat!"),
|
||||
# Just give time for everything to complete.
|
||||
SleepFrame(sleep=0.5),
|
||||
EndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
InterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
EndFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
)
|
||||
286
tests/test_openai_agent_service.py
Normal file
@@ -0,0 +1,286 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Tests for OpenAI Agent service."""

import asyncio
import os
import sys
import unittest.mock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Add src to path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMTextFrame,
    StartFrame,
    TextFrame,
)
from pipecat.processors.frame_processor import FrameDirection


class MockAgent:
    """Mock Agent for testing."""

    def __init__(self, name="Test Agent", instructions="Test instructions"):
        self.name = name
        self.instructions = instructions
        self.tools = []
        self.handoffs = []


class MockRunResult:
    """Mock RunResult for testing."""

    def __init__(self, final_output="Test response"):
        self.final_output = final_output


class MockStreamEvent:
    """Mock StreamEvent for testing."""

    def __init__(self, event_type, data=None, item=None):
        self.type = event_type
        self.data = data
        self.item = item


class MockMessageItem:
    """Mock message item for testing."""

    def __init__(self, content="Test content"):
        self.type = "message_output_item"
        self.content = content


class MockRunner:
    """Mock Runner for testing."""

    @staticmethod
    async def run(agent, input_text, context=None):
        return MockRunResult("Mocked response")

    @staticmethod
    def run_streamed(agent, input_text, context=None):
        class MockStreamResult:
            async def stream_events(self):
                yield MockStreamEvent("raw_response_event", data=MagicMock(delta="Test "))
                yield MockStreamEvent("raw_response_event", data=MagicMock(delta="response"))
                yield MockStreamEvent(
                    "run_item_stream_event", item=MockMessageItem("Test response")
                )

        return MockStreamResult()


@pytest.fixture
def mock_openai_agents():
    """Mock the OpenAI Agents SDK imports."""
    with patch.dict(
        "sys.modules",
        {
            "agents": MagicMock(),
            "agents.stream_events": MagicMock(),
            "agents.result": MagicMock(),
        },
    ):
        # Mock the classes and functions we need
        mock_agent = MagicMock()
        mock_agent.return_value = MockAgent()

        mock_runner = MagicMock()
        mock_runner.run = AsyncMock(return_value=MockRunResult())
        mock_runner.run_streamed = MagicMock(return_value=MockRunner.run_streamed(None, None))

        with (
            patch("pipecat.services.openai_agent.agent_service.Agent", mock_agent),
            patch("pipecat.services.openai_agent.agent_service.Runner", mock_runner),
        ):
            yield {
                "Agent": mock_agent,
                "Runner": mock_runner,
            }


@pytest.mark.asyncio
async def test_openai_agent_service_init(mock_openai_agents):
    """Test OpenAI Agent service initialization."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
    )

    assert service.agent.name == "Test Agent"
    assert service._streaming is True


@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_streaming(mock_openai_agents):
    """Test processing text frame with streaming enabled."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
    )

    # Mock the push_frame method to capture output
    output_frames = []

    async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
        output_frames.append(frame)

    service.push_frame = mock_push_frame

    # Process a text frame
    text_frame = TextFrame("Hello, agent!")
    await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)

    # Wait a bit for async processing
    await asyncio.sleep(0.1)

    # Check that appropriate frames were generated
    assert len(output_frames) > 0
    assert any(isinstance(frame, LLMFullResponseStartFrame) for frame in output_frames)


@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_non_streaming(mock_openai_agents):
    """Test processing text frame with streaming disabled."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=False
    )

    # Mock the push_frame method to capture output
    output_frames = []

    async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
        output_frames.append(frame)

    service.push_frame = mock_push_frame

    # Process a text frame
    text_frame = TextFrame("Hello, agent!")
    await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)

    # Wait a bit for async processing
    await asyncio.sleep(0.1)

    # Check that appropriate frames were generated
    assert len(output_frames) > 0


@pytest.mark.asyncio
async def test_openai_agent_service_update_config(mock_openai_agents):
    """Test updating agent configuration."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key"
    )

    # Update configuration
    service.update_agent_config(
        instructions="Updated instructions", model_config={"model": "gpt-4o", "temperature": 0.7}
    )

    assert service.agent.instructions == "Updated instructions"
    assert service.agent.model_config["model"] == "gpt-4o"


@pytest.mark.asyncio
async def test_openai_agent_service_session_context(mock_openai_agents):
    """Test session context management."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent",
        instructions="Test instructions",
        api_key="test-key",
        session_config={"user_id": "test-user"},
    )

    # Get initial context
    context = service.get_session_context()
    assert context["user_id"] == "test-user"

    # Update context
    service.update_session_context({"session_id": "test-session"})

    updated_context = service.get_session_context()
    assert updated_context["user_id"] == "test-user"
    assert updated_context["session_id"] == "test-session"


@pytest.mark.asyncio
async def test_openai_agent_service_add_tools(mock_openai_agents):
    """Test adding tools to the agent."""
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key"
    )

    # Define a test tool
    def test_tool():
        return "test result"

    # Add the tool
    await service.add_tool(test_tool)

    # Check if tool was added (this depends on the mock implementation)
    assert hasattr(service.agent, "tools")


@pytest.mark.asyncio
async def test_openai_agent_service_lifecycle(mock_openai_agents):
    """Test service lifecycle methods."""
    from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
    from pipecat.services.openai_agent.agent_service import OpenAIAgentService

    service = OpenAIAgentService(
        name="Test Agent", instructions="Test instructions", api_key="test-key"
    )

    # Test start
    start_frame = StartFrame()
    await service.start(start_frame)

    # Test cancel
    cancel_frame = CancelFrame()
    await service.cancel(cancel_frame)

    # Test stop
    end_frame = EndFrame()
    await service.stop(end_frame)


def test_openai_agent_service_import_error():
    """Test that import error is handled gracefully."""
    # Mock the import to fail
    with patch.dict("sys.modules", {"agents": None}):
        with pytest.raises(Exception) as exc_info:
            # This should trigger the import error
            import importlib

            import pipecat.services.openai_agent.agent_service

            importlib.reload(pipecat.services.openai_agent.agent_service)

        assert "Missing module" in str(exc_info.value)


if __name__ == "__main__":
    pytest.main([__file__])
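The import-error test in `tests/test_openai_agent_service.py` leans on a standard-library behavior: mapping a module name to `None` in `sys.modules` makes any later import of that name fail. The sketch below isolates that mechanism using only the standard library. The helper `import_blocked` and the module name `fake_agents_sdk` are made up for illustration and are not part of Pipecat or the OpenAI Agents SDK.

```python
import importlib
import sys
from unittest.mock import patch


def import_blocked(module_name: str) -> bool:
    """Return True if importing `module_name` fails while it is mapped to None.

    Hypothetical helper for illustration only.
    """
    # patch.dict restores sys.modules to its previous contents on exit,
    # so the block does not leak into later imports.
    with patch.dict(sys.modules, {module_name: None}):
        try:
            importlib.import_module(module_name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError.
            return True
        return False


if __name__ == "__main__":
    print(import_blocked("fake_agents_sdk"))
```

Because `patch.dict` undoes the change when the `with` block ends, the same name can be imported normally afterwards, which is why the test can reload the service module inside the block without affecting other tests.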
@@ -1,303 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Unit tests for ServiceSwitcher and related components."""

import unittest

from pipecat.frames.frames import (
    Frame,
    ManuallySwitchServiceFrame,
    TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.service_switcher import ServiceSwitcher, ServiceSwitcherStrategyManual
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import run_test


class MockFrameProcessor(FrameProcessor):
    """A test frame processor that tracks which frames it has processed."""

    def __init__(self, test_name: str, **kwargs):
        """Initialize the test processor with a name.

        Args:
            test_name: A unique name for this processor instance.
            **kwargs: Additional arguments passed to the parent FrameProcessor.
        """
        super().__init__(name=test_name, **kwargs)
        self.test_name = test_name
        self.processed_frames = []
        self.frame_count = 0

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process an incoming frame and track it.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)
        self.processed_frames.append(frame)
        self.frame_count += 1
        await self.push_frame(frame, direction)

    def reset_counters(self):
        """Reset the frame tracking counters."""
        self.processed_frames = []
        self.frame_count = 0


class TestServiceSwitcherStrategyManual(unittest.IsolatedAsyncioTestCase):
    """Test cases for ServiceSwitcherStrategyManual."""

    def setUp(self):
        """Set up test fixtures."""
        self.service1 = MockFrameProcessor("service1")
        self.service2 = MockFrameProcessor("service2")
        self.service3 = MockFrameProcessor("service3")
        self.services = [self.service1, self.service2, self.service3]

    def test_init_with_services(self):
        """Test initialization with a list of services."""
        strategy = ServiceSwitcherStrategyManual(self.services)

        self.assertEqual(strategy.services, self.services)
        self.assertEqual(strategy.active_service, self.service1)  # First service should be active

    def test_init_with_empty_services(self):
        """Test initialization with an empty list of services."""
        strategy = ServiceSwitcherStrategyManual([])

        self.assertEqual(strategy.services, [])
        self.assertIsNone(strategy.active_service)

    def test_handle_manually_switch_service_frame(self):
        """Test manual service switching with ManuallySwitchServiceFrame."""
        strategy = ServiceSwitcherStrategyManual(self.services)

        # Initially service1 should be active
        self.assertEqual(strategy.active_service, self.service1)
        self.assertNotEqual(strategy.active_service, self.service2)

        # Switch to service2
        switch_frame = ManuallySwitchServiceFrame(service=self.service2)
        strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)

        self.assertNotEqual(strategy.active_service, self.service1)
        self.assertEqual(strategy.active_service, self.service2)
        self.assertNotEqual(strategy.active_service, self.service3)

        # Switch to service3
        switch_frame = ManuallySwitchServiceFrame(service=self.service3)
        strategy.handle_frame(switch_frame, FrameDirection.DOWNSTREAM)

        self.assertNotEqual(strategy.active_service, self.service1)
        self.assertNotEqual(strategy.active_service, self.service2)
        self.assertEqual(strategy.active_service, self.service3)

    def test_handle_frame_unsupported_frame_type(self):
        """Test that unsupported frame types raise an error."""
        strategy = ServiceSwitcherStrategyManual(self.services)
        unsupported_frame = TextFrame(text="test")  # Not a ServiceSwitcherFrame

        with self.assertRaises(ValueError) as context:
            strategy.handle_frame(unsupported_frame, FrameDirection.DOWNSTREAM)

        self.assertIn("Unsupported frame type", str(context.exception))


class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
    """Test cases for ServiceSwitcher."""

    def setUp(self):
        """Set up test fixtures."""
        self.service1 = MockFrameProcessor("service1")
        self.service2 = MockFrameProcessor("service2")
        self.service3 = MockFrameProcessor("service3")
        self.services = [self.service1, self.service2, self.service3]

    def test_init_with_manual_strategy(self):
        """Test initialization with manual strategy."""
        switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)

        self.assertEqual(switcher.services, self.services)
        self.assertIsInstance(switcher.strategy, ServiceSwitcherStrategyManual)
        self.assertEqual(switcher.strategy.services, self.services)

    async def test_default_active_service(self):
        """Test that the initially-active service receives frames while others don't."""
        switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)

        # Reset counters
        for service in self.services:
            service.reset_counters()

        # Send some test frames
        frames_to_send = [
            TextFrame(text="Hello 1"),
            TextFrame(text="Hello 2"),
            TextFrame(text="Hello 3"),
        ]

        await run_test(
            switcher,
            frames_to_send=frames_to_send,
            expected_down_frames=[TextFrame, TextFrame, TextFrame],
            expected_up_frames=[],  # Expect no error frames
        )

        # Only service1 should have processed the text frames
        # Note: The service also receives StartFrame and EndFrame, so count those too
        text_frames = [f for f in self.service1.processed_frames if isinstance(f, TextFrame)]
        self.assertEqual(len(text_frames), 3)

        # Check that other services don't receive text frames (they might get StartFrame/EndFrame)
        service2_text_frames = [
            f for f in self.service2.processed_frames if isinstance(f, TextFrame)
        ]
        service3_text_frames = [
            f for f in self.service3.processed_frames if isinstance(f, TextFrame)
        ]
        self.assertEqual(len(service2_text_frames), 0)
        self.assertEqual(len(service3_text_frames), 0)

        # Verify the actual text frames processed
        for i, frame in enumerate(text_frames):
            self.assertEqual(frame.text, f"Hello {i + 1}")

    async def test_service_switching(self):
        """Test that after service switching using ManuallySwitchServiceFrame, the new active service receives frames while others don't."""
        switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)

        # Reset counters
        for service in self.services:
            service.reset_counters()

        # Send a test frame, a switch frame, and another test frame
        await run_test(
            switcher,
            frames_to_send=[
                TextFrame("Hello 1"),
                ManuallySwitchServiceFrame(service=self.service2),
                TextFrame("Hello 2"),
            ],
            expected_down_frames=[TextFrame, ManuallySwitchServiceFrame, TextFrame],
            expected_up_frames=[],  # Expect no error frames
        )

        # Verify service2 received the frame
        service1_text_frames = [
            f for f in self.service1.processed_frames if isinstance(f, TextFrame)
        ]
        service2_text_frames = [
            f for f in self.service2.processed_frames if isinstance(f, TextFrame)
        ]
        service3_text_frames = [
            f for f in self.service3.processed_frames if isinstance(f, TextFrame)
        ]

        self.assertEqual(len(service1_text_frames), 1)
        self.assertEqual(len(service2_text_frames), 1)
        self.assertEqual(len(service3_text_frames), 0)

        self.assertEqual(service1_text_frames[0].text, "Hello 1")
        self.assertEqual(service2_text_frames[0].text, "Hello 2")

    async def test_multi_service_switcher_targeting(self):
        """Test that ManuallySwitchServiceFrame targets the correct ServiceSwitcher in a multi-switcher pipeline."""
        # Create services for first switcher
        switcher1_service1 = MockFrameProcessor("switcher1_service1")
        switcher1_service2 = MockFrameProcessor("switcher1_service2")
        switcher1_services = [switcher1_service1, switcher1_service2]

        # Create services for second switcher
        switcher2_service1 = MockFrameProcessor("switcher2_service1")
        switcher2_service2 = MockFrameProcessor("switcher2_service2")
        switcher2_services = [switcher2_service1, switcher2_service2]

        # Create two service switchers
        switcher1 = ServiceSwitcher(switcher1_services, ServiceSwitcherStrategyManual)
        switcher2 = ServiceSwitcher(switcher2_services, ServiceSwitcherStrategyManual)

        # Create a pipeline with both switchers: switcher1 -> switcher2
        pipeline = Pipeline([switcher1, switcher2])

        # Reset counters
        for service in switcher1_services + switcher2_services:
            service.reset_counters()

        # Initially, both switchers should use their first services
        self.assertEqual(switcher1.strategy.active_service, switcher1_service1)
        self.assertEqual(switcher2.strategy.active_service, switcher2_service1)

        # Send frames to test the pipeline:
        # 1. Text frame (should go through both switchers' active services)
        # 2. Switch frame targeting switcher1's second service
        # 3. Text frame (should go through switcher1's new service and switcher2's original service)
        # 4. Switch frame targeting switcher2's second service
        # 5. Text frame (should go through switcher1's current service and switcher2's new service)
        await run_test(
            pipeline,
            frames_to_send=[
                TextFrame("Before any switches"),
                ManuallySwitchServiceFrame(service=switcher1_service2),  # Switch first switcher
                TextFrame("After switching first switcher"),
                ManuallySwitchServiceFrame(service=switcher2_service2),  # Switch second switcher
                TextFrame("After switching second switcher"),
            ],
            expected_down_frames=[
                TextFrame,
                ManuallySwitchServiceFrame,
                TextFrame,
                ManuallySwitchServiceFrame,
                TextFrame,
            ],
            expected_up_frames=[],  # Expect no error frames
        )

        # Verify the active services changed correctly
        self.assertEqual(switcher1.strategy.active_service, switcher1_service2)
        self.assertEqual(switcher2.strategy.active_service, switcher2_service2)

        # Verify frame distribution:
        # First text frame should go through switcher1_service1 and switcher2_service1
        switcher1_service1_texts = [
            f for f in switcher1_service1.processed_frames if isinstance(f, TextFrame)
        ]
        switcher2_service1_texts = [
            f for f in switcher2_service1.processed_frames if isinstance(f, TextFrame)
        ]

        # Second text frame should go through switcher1_service2 and switcher2_service1
        switcher1_service2_texts = [
            f for f in switcher1_service2.processed_frames if isinstance(f, TextFrame)
        ]

        # Third text frame should go through switcher1_service2 and switcher2_service2
        switcher2_service2_texts = [
            f for f in switcher2_service2.processed_frames if isinstance(f, TextFrame)
        ]

        # Verify frame counts and content
        self.assertEqual(len(switcher1_service1_texts), 1)
        self.assertEqual(switcher1_service1_texts[0].text, "Before any switches")

        self.assertEqual(len(switcher1_service2_texts), 2)
        self.assertEqual(switcher1_service2_texts[0].text, "After switching first switcher")
        self.assertEqual(switcher1_service2_texts[1].text, "After switching second switcher")

        self.assertEqual(len(switcher2_service1_texts), 2)
        self.assertEqual(switcher2_service1_texts[0].text, "Before any switches")
        self.assertEqual(switcher2_service1_texts[1].text, "After switching first switcher")

        self.assertEqual(len(switcher2_service2_texts), 1)
        self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")


if __name__ == "__main__":
    unittest.main()
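The removed `ServiceSwitcherStrategyManual` tests pin down a small contract: the first service starts active, an empty list leaves nothing active, a switch request changes the active service, and any other frame type raises `ValueError`. The sketch below is a hypothetical stand-in that satisfies those assertions. It is not Pipecat's implementation; `SwitchRequest` and `ManualStrategySketch` are invented names. Ignoring requests for services the strategy does not own is an assumption inferred from the multi-switcher test, where each switch frame changes only one of the two switchers.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SwitchRequest:
    """Hypothetical stand-in for a frame that asks for a specific service."""

    service: Any


class ManualStrategySketch:
    """Illustrative only: tracks one active service out of a fixed list."""

    def __init__(self, services: List[Any]):
        self.services = services
        # The first service starts active; an empty list leaves nothing active.
        self.active_service: Optional[Any] = services[0] if services else None

    def handle_frame(self, frame: Any) -> None:
        if not isinstance(frame, SwitchRequest):
            raise ValueError(f"Unsupported frame type: {type(frame).__name__}")
        # Assumption: requests naming a service this strategy does not own
        # are ignored, so several switchers can share one pipeline.
        if frame.service in self.services:
            self.active_service = frame.service


if __name__ == "__main__":
    strategy = ManualStrategySketch(["stt_a", "stt_b"])
    strategy.handle_frame(SwitchRequest("stt_b"))
    print(strategy.active_service)
```

Keeping the membership check inside the strategy means a pipeline with two switchers can forward every switch request downstream unchanged, and each switcher reacts only to requests that name one of its own services.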