Compare commits

...

25 Commits

Author SHA1 Message Date
James Hush
29d4a56663 Working on the 46 example 2025-09-17 11:59:16 +08:00
James Hush
373a09ecd6 Working on the 46 example 2025-09-17 11:59:10 +08:00
James Hush
07f54c48f3 This is working 2025-09-17 11:53:07 +08:00
James Hush
c8a3d65aa4 Save progress 2025-09-17 11:39:21 +08:00
James Hush
50a2a0dc86 ok its kinda working 2025-09-17 11:29:11 +08:00
James Hush
0421d97954 Save changes 2025-09-17 11:09:03 +08:00
James Hush
54c8f336c3 Save progress 2025-09-16 16:43:38 +08:00
James Hush
b086fbafe6 feat: Add OpenAI Agents SDK integration service
- Create new OpenAIAgentService that integrates OpenAI Agents SDK with Pipecat
- Support for agent loops, handoffs, guardrails, and session management
- Add streaming and non-streaming response modes
- Include comprehensive tool integration and error handling
- Add optional dependency for openai-agents package
- Create foundational examples showing basic usage and agent handoffs
- Add comprehensive tests with mocked dependencies
- Include detailed documentation and README

Key features:
- Real-time streaming responses compatible with Pipecat pipelines
- Agent handoffs for specialized task delegation
- Tool calling with automatic schema generation
- Input/output guardrails for safety and validation
- Session context management for conversation continuity
- Built-in tracing and monitoring integration

Examples:
- 45-openai-agent-basic.py: Basic agent with weather and trivia tools
- 46-openai-agent-handoffs.py: Multi-agent system with specialist handoffs
2025-09-16 16:20:30 +08:00
Mark Backman
cca90791c4 Merge pull request #2652 from pipecat-ai/mb/fix-audio-buffer-processor-has-audio
fix: AudioBufferProcessor has_audio returns based on user or bot audi…
2025-09-15 18:43:59 -07:00
Mark Backman
f2a5d408de fix: AudioBufferProcessor has_audio returns based on user or bot audio existing 2025-09-15 21:35:35 -04:00
Aleix Conchillo Flaqué
044c6eba46 Merge pull request #2655 from pipecat-ai/aleix/add-on-pipeline-finalized
PipelineTask: add on_pipeline_finished event
2025-09-15 15:32:04 -07:00
Aleix Conchillo Flaqué
db71089f5e PipelineTask: add on_pipeline_finished event
This deprecates `on_pipeline_stopped`, `on_pipeline_ended` and
`on_pipeline_cancelled`.
2025-09-15 15:28:33 -07:00
Aleix Conchillo Flaqué
f861f5066f Merge pull request #2654 from pipecat-ai/aleix/unify-on-client-disconnected
transports: on_client_disconnected only if remote client disconnects
2025-09-15 15:18:57 -07:00
kompfner
81cede2c60 Merge pull request #2653 from pipecat-ai/pk/llm-context-adapting-tests
`LLMContext`-adapting unit tests
2025-09-15 16:38:46 -04:00
kompfner
7603203230 Merge pull request #2644 from pipecat-ai/pk/run-inference-unit-tests
`run_inference` unit tests
2025-09-15 16:26:10 -04:00
Aleix Conchillo Flaqué
8569b61598 transports: on_client_disconnected only if remote client disconnects 2025-09-15 11:35:40 -07:00
Paul Kompfner
fe42187dc1 Implement LLMService.create_llm_specific_message() so that users don't need to just know what value of llm to provide to the LLMSpecificMessage constructor 2025-09-15 14:15:22 -04:00
Paul Kompfner
999e88c942 Add unit tests for AWSBedrockLLMAdapter.get_llm_invocation_params(), focusing on messages specifically 2025-09-15 12:08:21 -04:00
Paul Kompfner
c04df2f28b Add unit tests for AnthropicLLMAdapter.get_llm_invocation_params(), focusing on messages specifically 2025-09-15 11:55:48 -04:00
Paul Kompfner
100ef0ab5c Add unit tests for GeminiLLMAdapter.get_llm_invocation_params(), focusing on messages specifically 2025-09-15 11:38:23 -04:00
Paul Kompfner
42886d7105 Add unit tests for OpenAILLMAdapter.get_llm_invocation_params(), focusing on messages specifically. Also, fix a bug in OpenAILLMAdapter (found thanks to the unit tests) where it wasn't "unwrapping" LLMSpecificMessages. 2025-09-15 11:17:11 -04:00
Mark Backman
22cbba002a Merge pull request #2651 from pipecat-ai/mb/heygen-bot-speaking-frame
fix: push BotStartedSpeakingFrame in HeyGenVideoService
2025-09-15 08:02:25 -07:00
Mark Backman
c873798ce5 fix: push BotStartedSpeakingFrame in HeyGenVideoService 2025-09-14 08:12:44 -04:00
Paul Kompfner
786387722a Fix an issue in AWSBedrockLLMService.run_inference—exceptions should propagate, just like with other LLM services 2025-09-12 11:09:32 -04:00
Paul Kompfner
9f82c6b4a4 Add unit tests for run_inference 2025-09-12 11:07:11 -04:00
29 changed files with 7447 additions and 4056 deletions

285
AGENTS.md Normal file
View File

@@ -0,0 +1,285 @@
# AGENTS.md
## Project Overview
Pipecat is an open-source Python framework for building real-time voice and multimodal conversational AI agents. The codebase is organized around a pipeline architecture where data flows through connected services (STT → LLM → TTS).
## Development Environment Setup
### Prerequisites
- **Minimum Python Version:** 3.10
- **Recommended Python Version:** 3.12
- **Package Manager:** uv (recommended) or pip
### Setup Commands
```bash
# Clone the repository
git clone https://github.com/pipecat-ai/pipecat.git
cd pipecat
# Install dependencies with uv (recommended)
uv sync --group dev --all-extras \
--no-extra gstreamer \
--no-extra krisp \
--no-extra local \
--no-extra ultravox
# Or with pip
pip install -e ".[dev]"
# Install pre-commit hooks
uv run pre-commit install
# Set up environment variables
cp env.example .env
```
## Build and Test Commands
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest tests/test_name.py
# Run tests with coverage
uv run pytest --cov=pipecat --cov-report=html
```
### Code Quality
```bash
# Format code (required before commits)
uv run ruff format
# Lint code
uv run ruff check
# Type checking
uv run mypy src/pipecat
# Run pre-commit checks manually
uv run pre-commit run --all-files
```
### Documentation
```bash
# Build API documentation
cd docs/api
./build-docs.sh
# Build docs manually
sphinx-build -b html . _build/html -W --keep-going
```
## Code Style Guidelines
### Python Standards
- **Formatting:** Strict PEP 8 via Ruff
- **Docstrings:** Google-style format
- **Type Hints:** Required for all public APIs
- **Import Organization:** Automated via Ruff
### Docstring Conventions
- **Classes:** Describe purpose + `__init__` with complete `Args:` section
- **Dataclasses:** Use `Parameters:` section, no `__init__` docstring
- **Methods:** Include `Args:` and `Returns:` sections
- **Properties:** Must have `Returns:` section
- **Examples:** Use `Examples:` section with `::` syntax
### File Organization
```
src/pipecat/ # Main package
├── processors/ # Frame processors
├── services/ # AI service integrations
├── transports/ # Communication layers
├── frames/ # Data frame definitions
└── pipeline/ # Pipeline orchestration
examples/foundational/ # Step-by-step tutorials
tests/ # Test suite
```
## Testing Instructions
### Test Structure
- **Unit Tests:** Test individual components in isolation
- **Integration Tests:** Test service interactions
- **Example Tests:** Validate foundational examples work
### Adding Tests
```bash
# Test naming convention
test_<component>_<functionality>.py
# Run specific test pattern
uv run pytest -k "test_pipeline"
# Run with debugging
uv run pytest -s -vv tests/test_name.py::test_function
```
### Pre-commit Requirements
All commits must pass:
- Ruff formatting
- Ruff linting
- Type checking
- Basic test suite
## Dependency Management
### Using uv (Recommended)
```bash
# Add runtime dependency
uv add package-name
# Add optional dependency
uv add --optional service package-name
# Add development dependency
uv add --group dev package-name
# Update lockfile
uv lock
# Sync dependencies
uv sync
```
### Important Notes
- **Always commit both `pyproject.toml` and `uv.lock` together**
- **Never manually edit `uv.lock`** - it's auto-generated
- **Use extras for optional service dependencies** (e.g., `[openai]`, `[cartesia]`)
## Project Structure Guidelines
### Service Integration
When adding new AI services:
1. Create service class in `src/pipecat/services/<provider>/`
2. Follow existing patterns (e.g., STTService, LLMService)
3. Add to appropriate extras in `pyproject.toml`
4. Include tests in `tests/`
5. Add documentation examples
### Frame Processing
For custom processors:
1. Inherit from `FrameProcessor`
2. Implement `process_frame()` method. ALWAYS explicitly call `await super().process_frame(frame, direction)` at the top of this method.
3. Handle frame direction (FrameDirection.UPSTREAM/DOWNSTREAM)
4. Add proper type hints and docstrings
### Transport Implementation
For new transport layers:
1. Inherit from `BaseTransport`
2. Implement required abstract methods
3. Handle connection lifecycle
4. Support both input and output streams
## Security Considerations
### API Keys
- **Never commit API keys** to the repository
- **Use environment variables** for all secrets
- **Reference `env.example`** for required variables
- **Use `.env` files** for local development
### Input Validation
- **Validate all external inputs** (audio, text, API responses)
- **Sanitize user data** before processing
- **Handle rate limiting** for external services
- **Implement proper timeout handling**
## Performance Guidelines
### Memory Management
- **Clean up resources** in transport disconnection handlers
- **Use async context managers** for service connections
- **Implement proper frame lifecycle** management
### Latency Optimization
- **Choose appropriate STT services** for latency requirements
- **Use streaming TTS** when possible
- **Implement connection pooling** for HTTP services
- **Consider WebRTC** for real-time applications
## Common Patterns
### Error Handling
```python
@transport.event_handler("on_error")
async def on_error(transport, error):
logger.error(f"Transport error: {error}")
# Shutdown the pipeline
await task.queue_frame(EndFrame())
```
### Service Configuration
```python
# Use environment variables for configuration
service = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY", ""),
model="gpt-4o",
params={"temperature": 0.7}
)
```
### Pipeline Assembly
```python
pipeline = Pipeline([
transport.input(),
stt_service,
context_aggregator.user(),
llm_service,
tts_service,
transport.output(),
context_aggregator.assistant(),
])
```
## Commit and PR Guidelines
### Commit Message Format
```
<type>(<scope>): <description>
[optional body]
[optional footer]
```
Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
### PR Requirements
- **All tests must pass**
- **Code must be properly formatted** (Ruff)
- **Include appropriate tests** for new functionality
- **Update documentation** if needed
- **Reference related issues** in description
### Review Process
1. Automated checks must pass
2. Manual code review by maintainers
3. Documentation review for user-facing changes
4. Integration testing for service additions
## Troubleshooting
### Common Issues
- **Import errors:** Run `uv sync` to ensure dependencies are installed
- **Test failures:** Check environment variables in `.env`
- **Format errors:** Run `uv run ruff format` before committing
- **Type errors:** Ensure all public methods have type hints
### Development Tips
- **Use foundational examples** as starting points for testing
- **Check existing services** for integration patterns
- **Run tests frequently** during development
- **Use IDE integration** for Ruff formatting
### Getting Help
- **Documentation:** [docs.pipecat.ai](https://docs.pipecat.ai)
- **Issues:** [GitHub Issues](https://github.com/pipecat-ai/pipecat/issues)

View File

@@ -5,6 +5,39 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
fired when the pipeline is done running. This can be the result of a
`StopFrame`, `CancelFrame` or `EndFrame`.
```python
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
...
```
### Deprecated
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
instead.
### Fixed
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
when a bot speaks and user input is blocked.
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
`on_client_disconnected` would be triggered when the bot ends the
conversation. That is, `on_client_disconnected` should only be triggered when
the remote client actually disconnects.
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
was blocked from moving through the Pipeline.
## [0.0.85] - 2025-09-12
### Added

View File

@@ -0,0 +1,205 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""
Basic OpenAI Agent service example.
This example demonstrates how to use the OpenAI Agents SDK within a Pipecat
pipeline to create an interactive agent with tool calling capabilities.
Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""
import os
import random
from typing import Any, List
# Import agents SDK for tools and agent creation
from agents import Agent, function_tool
from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# Transport configuration
transport_params = {
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}
@function_tool
def get_weather(location: str) -> str:
"""Get the current weather for a location.
Args:
location: The location to get weather for
Returns:
A weather description string
"""
# Mock weather data - in real usage, integrate with weather API
weather_data = {
"San Francisco": "Foggy, 65°F",
"New York": "Sunny, 72°F",
"London": "Rainy, 59°F",
"Tokyo": "Partly cloudy, 68°F",
}
return weather_data.get(location, f"Weather data not available for {location}")
@function_tool
def get_random_fact() -> str:
"""Get a random interesting fact.
Returns:
A random fact string
"""
facts = [
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
"Octopuses have three hearts and blue blood.",
"The Great Wall of China isn't visible from space with the naked eye.",
"Bananas are berries, but strawberries aren't.",
]
return random.choice(facts)
def get_random_fact_tool():
"""Example tool function for random facts."""
def get_random_fact() -> str:
"""Get a random interesting fact.
Returns:
A random fact string.
"""
facts = [
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
"A group of flamingos is called a 'flamboyance'.",
"Octopuses have three hearts and blue blood.",
"The Great Wall of China isn't visible from space with the naked eye.",
"Bananas are berries, but strawberries aren't.",
]
return random.choice(facts)
return get_random_fact
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info("Starting OpenAI Agent bot")
# Set up STT for speech recognition
stt = DeepgramSTTService(
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
model="nova-2",
)
# Set up TTS for voice output
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY", ""),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
# Create tools for the agent
tools: list[Any] = [
get_weather,
get_random_fact,
]
# Create the agent with tools
agent = Agent(
name="Assistant",
instructions="""You are a helpful assistant with access to weather information and random facts.
You can:
- Check weather for any location using the get_weather tool
- Share interesting facts using the get_random_fact tool
- Have natural conversations
Be friendly, informative, and engaging in your responses.""",
tools=tools,
)
# Initialize the OpenAI Agent service with the pre-configured agent
agent_service = OpenAIAgentService(
agent=agent,
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
# Set up conversation context with initial system message
messages: List[ChatCompletionMessageParam] = [
{
"role": "system",
"content": "You are a helpful assistant with access to weather information and random facts. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = agent_service.create_context_aggregator(context)
# Create the processing pipeline with context aggregators
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # Speech to text
context_aggregator.user(), # User responses
agent_service, # OpenAI Agent processing
tts, # Text to speech
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
# Send an initial greeting when client connects
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info("Client connected, sending greeting")
# Kick off the conversation by adding system message and running LLM
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info("Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -0,0 +1,276 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""
Advanced OpenAI Agent service example with handoffs.
This example demonstrates how to use multiple agents with handoffs in the
OpenAI Agents SDK within a Pipecat pipeline, showcasing agent orchestration
and specialization.
Requirements:
- OpenAI API key
- OpenAI Agents SDK: pip install openai-agents
"""
import os
import random
from typing import Any, Dict, List
from dotenv import load_dotenv
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from pipecat.frames.frames import LLMRunFrame, TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# Transport configuration
transport_params = {
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
}
def create_weather_tools():
"""Create weather-related tools."""
def get_weather(location: str) -> str:
"""Get current weather for a location."""
conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"]
temp = random.randint(-10, 35)
condition = random.choice(conditions)
return f"The weather in {location} is {condition} with a temperature of {temp}°C."
def get_forecast(location: str, days: int = 3) -> str:
"""Get weather forecast for multiple days."""
forecast = []
for i in range(days):
conditions = ["sunny", "cloudy", "rainy", "snowy"]
temp = random.randint(-5, 30)
condition = random.choice(conditions)
day = "today" if i == 0 else f"in {i} day{'s' if i > 1 else ''}"
forecast.append(f"{day.capitalize()}: {condition}, {temp}°C")
return f"Weather forecast for {location}:\n" + "\n".join(forecast)
return [get_weather, get_forecast]
def create_trivia_tools():
"""Create trivia and fact tools."""
def get_random_fact() -> str:
"""Get a random interesting fact."""
facts = [
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
"A group of flamingos is called a 'flamboyance'.",
"Octopuses have three hearts and blue blood.",
"The Great Wall of China isn't visible from space with the naked eye.",
"Bananas are berries, but strawberries aren't.",
"Wombat poop is cube-shaped.",
"A shrimp's heart is in its head.",
"It's impossible to hum while holding your nose.",
]
return random.choice(facts)
def get_science_fact() -> str:
"""Get a random science fact."""
facts = [
"The speed of light in a vacuum is approximately 299,792,458 meters per second.",
"DNA stands for Deoxyribonucleic Acid.",
"The human brain uses about 20% of the body's total energy.",
"There are more possible games of chess than atoms in the observable universe.",
"A single bolt of lightning contains enough energy to toast 100,000 slices of bread.",
]
return random.choice(facts)
return [get_random_fact, get_science_fact]
def create_math_tools():
"""Create math calculation tools."""
def calculate(expression: str) -> str:
"""Safely calculate a mathematical expression."""
try:
# Only allow basic math operations for safety
allowed_chars = set("0123456789+-*/.() ")
if not all(c in allowed_chars for c in expression):
return "Sorry, I can only calculate basic math expressions with +, -, *, /, and parentheses."
result = eval(expression)
return f"{expression} = {result}"
except Exception as e:
return f"Error calculating '{expression}': {str(e)}"
def generate_math_problem() -> str:
"""Generate a random math problem."""
operations = ["+", "-", "*"]
a = random.randint(1, 20)
b = random.randint(1, 20)
op = random.choice(operations)
if op == "+":
answer = a + b
elif op == "-":
answer = a - b
else: # multiplication
answer = a * b
return f"Here's a math problem for you: {a} {op} {b} = ?"
return [calculate, generate_math_problem]
async def create_specialist_agents():
"""Create specialized agents for different domains."""
# Weather specialist agent
weather_agent = OpenAIAgentService(
name="Weather Specialist",
instructions="""You are a weather specialist. You provide detailed weather information,
forecasts, and weather-related advice. Use your tools to get accurate weather data.
Be informative and helpful about weather conditions and what they might mean for
outdoor activities.""",
tools=create_weather_tools(),
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
# Trivia specialist agent
trivia_agent = OpenAIAgentService(
name="Trivia Master",
instructions="""You are a trivia and facts specialist. You love sharing interesting
facts, trivia, and educational content. Use your tools to provide fascinating
information and engage users with fun facts. Make learning enjoyable!""",
tools=create_trivia_tools(),
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
# Math specialist agent
math_agent = OpenAIAgentService(
name="Math Helper",
instructions="""You are a mathematics specialist. You help with calculations,
math problems, and mathematical concepts. Use your tools to solve problems
and generate practice questions. Make math accessible and fun!""",
tools=create_math_tools(),
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
return weather_agent, trivia_agent, math_agent
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info("Starting OpenAI Agent bot with handoffs")
# Set up STT for speech recognition
stt = DeepgramSTTService(
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
model="nova-2",
)
# Set up TTS for voice output
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY", ""),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
# Create specialist agents
weather_agent, trivia_agent, math_agent = await create_specialist_agents()
# Create the main triage agent that can hand off to specialists
triage_agent = OpenAIAgentService(
name="Assistant Coordinator",
instructions="""You are a helpful assistant coordinator. Your role is to understand
what the user needs and direct them to the right specialist:
- For weather questions, forecasts, or outdoor activity planning -> Weather Specialist
- For interesting facts, trivia, or educational content -> Trivia Master
- For calculations, math problems, or mathematical help -> Math Helper
If the request doesn't clearly fit a specialist, you can handle general conversation
yourself. Always be friendly and explain when you're connecting them to a specialist.""",
handoffs=[weather_agent.agent, trivia_agent.agent, math_agent.agent], # type: ignore
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
# Set up conversation context with initial system message
messages: List[ChatCompletionMessageParam] = [
{
"role": "system",
"content": "You are a helpful assistant coordinator with access to weather information, trivia, and math tools. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = triage_agent.create_context_aggregator(context)
# Create the processing pipeline with context aggregators
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # Speech to text
context_aggregator.user(), # User responses
triage_agent, # OpenAI Agent processing
tts, # Text to speech
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
# Send an initial greeting when client connects
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info("Client connected, sending greeting")
# Kick off the conversation by adding system message and running LLM
messages.append(
{
"role": "system",
"content": "Please introduce yourself to the user as an AI assistant coordinator who works with specialists for weather, trivia, and math topics.",
}
)
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info("Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -34,7 +34,7 @@ dependencies = [
"pyloudnorm~=0.1.1",
"resampy~=0.4.3",
"soxr~=0.5.0",
"openai>=1.74.0,<=1.99.1",
"openai>=1.74.0,<2.0.0",
# Pinning numba to resolve package dependencies
"numba==0.61.2",
"wait_for2>=0.4.1; python_version<'3.12'",
@@ -74,7 +74,7 @@ langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-ope
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
lmnt = [ "websockets>=13.1,<15.0" ]
local = [ "pyaudio~=0.2.14" ]
mcp = [ "mcp[cli]~=1.9.4" ]
mcp = [ "mcp[cli]>=1.11.0,<2.0.0" ]
mem0 = [ "mem0ai~=0.1.94" ]
mistral = []
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
@@ -83,7 +83,8 @@ nim = []
neuphonic = [ "websockets>=13.1,<15.0" ]
noisereduce = [ "noisereduce~=3.0.3" ]
openai = [ "websockets>=13.1,<15.0" ]
openpipe = [ "openpipe~=4.50.0" ]
openai-agent = [ "openai-agents~=0.3.0" ]
# openpipe = [ "openpipe~=4.50.0" ] # Temporarily disabled due to openai version conflict
openrouter = []
perplexity = []
playht = [ "websockets>=13.1,<15.0" ]

View File

@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMSpecificMessage,
NotGiven,
)
# Should be a TypedDict
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
Subclasses must implement provider-specific conversion logic.
"""
@property
@abstractmethod
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
Returns:
The identifier string.
"""
pass
@abstractmethod
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
"""Get provider-specific LLM invocation parameters from a universal LLM context.
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
"""
pass
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
"""Get messages from the LLM context, including standard and LLM-specific messages.
Args:
context: The LLM context containing messages.
Returns:
List of messages including standard and LLM-specific messages.
"""
return context.get_messages(self.id_for_llm_specific_messages)
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
"""Convert tools from standard format to provider format.

View File

@@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
to the specific format required by Anthropic's Claude models for function calling.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
return "anthropic"
def get_llm_invocation_params(
self, context: LLMContext, enable_prompt_caching: bool
) -> AnthropicLLMInvocationParams:
@@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking Anthropic's LLM API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,
"messages": (
@@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
List of messages in a format ready for logging about Anthropic.
"""
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("anthropic")
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""

View File

@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
specific function-calling format, enabling tool use with Nova Sonic models.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.

View File

@@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
into AWS Bedrock's expected tool format for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
return "aws"
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
@@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking AWS Bedrock's LLM API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,
"messages": messages.messages,
@@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
List of messages in a format ready for logging about AWS Bedrock.
"""
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("anthropic")
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""

View File

@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Google."""
return "google"
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
Returns:
Dictionary of parameters for Gemini's API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,
"messages": messages.messages,
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
List of messages in a format ready for logging about Gemini.
"""
# Get messages in Gemini's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
messages_for_logging.append(obj)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("google")
@dataclass
class ConvertedMessages:
"""Container for Google-formatted messages converted from universal context."""

View File

@@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMContextToolChoice,
LLMSpecificMessage,
NotGiven,
)
@@ -47,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
return "openai"
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
@@ -57,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
Dictionary of parameters for OpenAI's ChatCompletion API.
"""
return {
"messages": self._from_universal_context_messages(self._get_messages(context)),
"messages": self._from_universal_context_messages(self.get_messages(context)),
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
"tool_choice": context.tool_choice,
@@ -91,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
List of messages in a format ready for logging about OpenAI.
"""
msgs = []
for message in self._get_messages(context):
for message in self.get_messages(context):
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
@@ -104,14 +110,18 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
msgs.append(msg)
return msgs
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("openai")
def _from_universal_context_messages(
self, messages: List[LLMContextMessage]
) -> List[ChatCompletionMessageParam]:
# Just a pass-through: messages are already the right type
return messages
result = []
for message in messages:
if isinstance(message, LLMSpecificMessage):
# Extract the actual message content from LLMSpecificMessage
result.append(message.message)
else:
# Standard message, pass through unchanged
result.append(message)
return result
def _from_standard_tool_choice(
self, tool_choice: LLMContextToolChoice | NotGiven

View File

@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
OpenAI's Realtime API for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.

View File

@@ -115,9 +115,28 @@ class PipelineTask(BasePipelineTask):
- on_frame_reached_downstream: Called when downstream frames reach the sink
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
- on_pipeline_started: Called when pipeline starts with StartFrame
- on_pipeline_stopped: Called when pipeline stops with StopFrame
- on_pipeline_ended: Called when pipeline ends with EndFrame
- on_pipeline_cancelled: Called when pipeline is cancelled
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
.. deprecated:: 0.0.86
Use `on_pipeline_finished` instead.
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
This includes:
- StopFrame: pipeline was stopped (processors keep connections open)
- EndFrame: pipeline ended normally
- CancelFrame: pipeline was cancelled
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
the frame if they need to handle specific cases.
Example::
@@ -128,6 +147,10 @@ class PipelineTask(BasePipelineTask):
@task.event_handler("on_idle_timeout")
async def on_pipeline_idle_timeout(task):
...
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame):
...
"""
def __init__(
@@ -264,6 +287,7 @@ class PipelineTask(BasePipelineTask):
self._register_event_handler("on_pipeline_stopped")
self._register_event_handler("on_pipeline_ended")
self._register_event_handler("on_pipeline_cancelled")
self._register_event_handler("on_pipeline_finished")
@property
def params(self) -> PipelineParams:
@@ -292,6 +316,27 @@ class PipelineTask(BasePipelineTask):
"""
return self._turn_trace_observer
def event_handler(self, event_name: str):
"""Decorator for registering event handlers.
Args:
event_name: The name of the event to handle.
Returns:
The decorator function that registers the handler.
"""
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
DeprecationWarning,
)
return super().event_handler(event_name)
def add_observer(self, observer: BaseObserver):
"""Add an observer to monitor pipeline execution.
@@ -534,6 +579,7 @@ class PipelineTask(BasePipelineTask):
)
finally:
await self._call_event_handler("on_pipeline_cancelled", frame)
await self._call_event_handler("on_pipeline_finished", frame)
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
@@ -681,9 +727,11 @@ class PipelineTask(BasePipelineTask):
self._pipeline_start_event.set()
elif isinstance(frame, EndFrame):
await self._call_event_handler("on_pipeline_ended", frame)
await self._call_event_handler("on_pipeline_finished", frame)
self._pipeline_end_event.set()
elif isinstance(frame, StopFrame):
await self._call_event_handler("on_pipeline_stopped", frame)
await self._call_event_handler("on_pipeline_finished", frame)
self._pipeline_end_event.set()
elif isinstance(frame, CancelFrame):
self._pipeline_end_event.set()

View File

@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
return self._num_channels
def has_audio(self) -> bool:
"""Check if both user and bot audio buffers contain data.
"""Check if either user or bot audio buffers contain data.
Returns:
True if both buffers contain audio data.
True if either buffer contains audio data.
"""
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
self._bot_audio_buffer
)

View File

@@ -811,60 +811,55 @@ class AWSBedrockLLMService(LLMService):
Returns:
The LLM's response as a string, or None if no response is generated.
"""
try:
messages = []
system = []
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
messages = params["messages"]
system = params["system"] # [{"text": "system message"}]
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
messages = []
system = []
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
messages = params["messages"]
system = params["system"] # [{"text": "system message"}]
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
if system:
request_params["system"] = system
if system:
request_params["system"] = system
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
return None
except Exception as e:
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
return None
async def _create_converse_stream(self, client, request_params):

View File

@@ -240,6 +240,7 @@ class HeyGenVideoService(AIService):
# As soon as we receive actual audio, the base output transport will create a
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
await self.stop_ttfb_metrics()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -44,7 +44,7 @@ from pipecat.frames.frames import (
StartFrame,
UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -195,6 +195,17 @@ class LLMService(AIService):
"""
return self._adapter
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return self.get_llm_adapter().create_llm_specific_message(message)
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.

View File

@@ -0,0 +1,209 @@
# OpenAI Agents SDK Integration
This service integrates the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) with Pipecat, enabling powerful agentic workflows with features like:
- **Agent loops** with tool calling and response streaming
- **Handoffs** between specialized agents
- **Guardrails** for input/output validation
- **Sessions** with automatic conversation history
- **Built-in tracing** and monitoring
## Installation
Install the OpenAI Agents SDK dependency:
```bash
pip install "pipecat-ai[openai-agent]"
# or
uv add "pipecat-ai[openai-agent]"
```
## Basic Usage
```python
from pipecat.services.openai_agent import OpenAIAgentService
# Create a simple agent
agent_service = OpenAIAgentService(
name="Assistant",
instructions="You are a helpful assistant.",
api_key=os.getenv("OPENAI_API_KEY"),
streaming=True,
)
# Use in a pipeline
pipeline = Pipeline([
transport.input(),
stt,
agent_service,
tts,
transport.output(),
])
```
## Features
### Tool Integration
```python
def get_weather(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny, 22°C"
agent_service = OpenAIAgentService(
name="Weather Assistant",
instructions="Help users with weather information.",
tools=[get_weather],
api_key=os.getenv("OPENAI_API_KEY"),
)
```
### Agent Handoffs
```python
# Create specialized agents
weather_agent = OpenAIAgentService(
name="Weather Specialist",
instructions="Provide weather information and forecasts.",
tools=[get_weather, get_forecast],
)
trivia_agent = OpenAIAgentService(
name="Trivia Master",
instructions="Share interesting facts and trivia.",
tools=[get_random_fact],
)
# Create coordinator that can hand off to specialists
coordinator = OpenAIAgentService(
name="Coordinator",
instructions="Route users to the right specialist.",
handoffs=[weather_agent.agent, trivia_agent.agent],
)
```
### Guardrails
```python
from agents import InputGuardrail, GuardrailFunctionOutput
async def content_filter(ctx, agent, input_data):
# Check input for appropriate content
if is_inappropriate(input_data):
return GuardrailFunctionOutput(
tripwire_triggered=True,
output_info="Content not allowed"
)
return GuardrailFunctionOutput(tripwire_triggered=False)
agent_service = OpenAIAgentService(
name="Safe Assistant",
instructions="You are a helpful and safe assistant.",
input_guardrails=[InputGuardrail(guardrail_function=content_filter)],
)
```
### Session Management
```python
agent_service = OpenAIAgentService(
name="Personal Assistant",
instructions="Remember user preferences and context.",
session_config={
"user_id": "user_123",
"memory_enabled": True,
}
)
# Update session context dynamically
agent_service.update_session_context({
"user_preferences": {"language": "en", "style": "formal"}
})
```
## Configuration Options
### Basic Parameters
- `name`: Agent identifier for handoffs and tracing
- `instructions`: System prompt defining agent behavior
- `api_key`: OpenAI API key (or use `OPENAI_API_KEY` env var)
- `streaming`: Enable real-time token streaming (default: True)
### Advanced Configuration
- `tools`: List of callable functions for the agent to use
- `handoffs`: List of other agents this agent can transfer to
- `input_guardrails`: Input validation and filtering
- `output_guardrails`: Output validation and filtering
- `model_config`: Model settings (model, temperature, etc.)
- `session_config`: Session and memory configuration
### Model Configuration
```python
agent_service = OpenAIAgentService(
name="Precise Assistant",
instructions="Provide accurate, concise responses.",
model_config={
"model": "gpt-4o",
"temperature": 0.1,
"max_tokens": 150,
}
)
```
## Examples
See the foundational examples:
- [`45-openai-agent-basic.py`](../examples/foundational/45-openai-agent-basic.py) - Basic agent with tools
- [`46-openai-agent-handoffs.py`](../examples/foundational/46-openai-agent-handoffs.py) - Multi-agent system with handoffs
## Methods
### Core Methods
- `update_agent_config()` - Update instructions and model settings
- `add_tool()` - Add new tools dynamically
- `add_handoff_agent()` - Add handoff destinations
- `get_session_context()` - Get current session state
- `update_session_context()` - Update session variables
### Lifecycle Methods
Inherited from `AIService`:
- `start()` - Initialize the agent
- `stop()` - Clean up resources
- `cancel()` - Cancel ongoing operations
## Integration with Pipecat
The service processes `TextFrame` inputs and generates:
- `LLMFullResponseStartFrame` - Response beginning
- `LLMTextFrame` - Streaming text tokens (if streaming enabled)
- `LLMFullResponseEndFrame` - Response completion
This integrates seamlessly with Pipecat's conversation pipeline and context aggregators.
## Error Handling
The service includes robust error handling for:
- Missing API keys or SDK installation
- Agent processing failures
- Network connectivity issues
- Malformed tool responses
Errors are emitted as `ErrorFrame` objects in the pipeline.
## Requirements
- OpenAI API key
- `openai-agents` package
- Python 3.10+
## Limitations
- Currently supports OpenAI models only (via Agents SDK)
- Handoffs work within individual requests (no cross-request state)
- Real-time voice features require additional setup

View File

@@ -0,0 +1,11 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Agents SDK service for Pipecat integration."""
from .agent_service import OpenAIAgentService
__all__ = ["OpenAIAgentService"]

View File

@@ -0,0 +1,567 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Agents SDK integration service.
Provides integration with the OpenAI Agents SDK for building AI applications
within Pipecat pipelines. This service allows leveraging agent loops, handoffs,
guardrails, sessions, and tools from the OpenAI Agents SDK.
"""
import asyncio
import os
from dataclasses import dataclass
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Protocol,
Sequence,
Union,
override,
runtime_checkable,
)
from loguru import logger
try:
from agents import Agent, InputGuardrail, OutputGuardrail, Runner, Tool
from agents.result import RunResult, RunResultStreaming
from agents.stream_events import StreamEvent
except ImportError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use OpenAI Agents SDK, you need to `pip install openai-agents`. "
"Also, set `OPENAI_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TextFrame,
UserImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
@runtime_checkable
class ToolLike(Protocol):
"""Protocol for tool-like objects."""
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Tool call interface."""
...
@runtime_checkable
class AgentLike(Protocol):
"""Protocol for agent-like objects."""
name: str
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Agent call interface."""
...
@dataclass
class OpenAIAgentContextAggregatorPair:
"""Pair of OpenAI Agent context aggregators for user and assistant messages.
Parameters:
_user: User context aggregator for processing user messages.
_assistant: Assistant context aggregator for processing assistant messages.
"""
_user: "OpenAIAgentUserContextAggregator"
_assistant: "OpenAIAgentAssistantContextAggregator"
def user(self) -> "OpenAIAgentUserContextAggregator":
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> "OpenAIAgentAssistantContextAggregator":
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class OpenAIAgentService(AIService):
"""OpenAI Agents SDK service for Pipecat.
Integrates the OpenAI Agents SDK with Pipecat's pipeline architecture,
enabling advanced agentic workflows with features like handoffs, guardrails,
sessions, and tools within real-time conversational AI applications.
The service processes text input frames and generates streaming responses
using the agent's configured capabilities.
"""
def __init__(
self,
*,
agent: Optional[Agent] = None,
name: str = "Assistant",
instructions: Union[str, Sequence[str]] = "You are a helpful assistant.",
handoffs: Optional[Sequence[AgentLike]] = None,
tools: Optional[Sequence[ToolLike]] = None,
input_guardrails: Optional[Sequence[InputGuardrail]] = None,
output_guardrails: Optional[Sequence[OutputGuardrail]] = None,
model_config: Optional[Dict[str, Any]] = None,
session_config: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
streaming: bool = True,
**kwargs,
):
"""Initialize the OpenAI Agent service.
Args:
agent: Pre-configured Agent instance. If provided, other agent configuration
parameters will be ignored.
name: Name of the agent for identification and handoffs.
instructions: System instructions that define the agent's behavior.
handoffs: List of other agents this agent can hand off to.
tools: List of callable functions the agent can use as tools.
input_guardrails: List of input validation guardrails.
output_guardrails: List of output validation guardrails.
model_config: Configuration for the underlying language model.
session_config: Configuration for session management.
api_key: OpenAI API key. If not provided, will use OPENAI_API_KEY env var.
streaming: Whether to use streaming responses for real-time output.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(**kwargs)
# Set up API key
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
elif not os.getenv("OPENAI_API_KEY"):
logger.warning("No OpenAI API key provided. Set OPENAI_API_KEY environment variable.")
# Create or use existing agent
if agent:
self._agent = agent
else:
# Convert sequences to lists and handle string instructions
agent_handoffs: List[Any] = list(handoffs) if handoffs else []
agent_tools: List[Any] = list(tools) if tools else []
agent_input_guardrails: List[Any] = list(input_guardrails) if input_guardrails else []
agent_output_guardrails: List[Any] = (
list(output_guardrails) if output_guardrails else []
)
# Handle instructions - convert sequence to string if needed
if isinstance(instructions, str):
agent_instructions = instructions
else:
agent_instructions = " ".join(str(instr) for instr in instructions)
self._agent = Agent(
name=name,
instructions=agent_instructions,
handoffs=agent_handoffs,
tools=agent_tools,
input_guardrails=agent_input_guardrails,
output_guardrails=agent_output_guardrails,
model=model_config.get("model", "gpt-4o") if model_config else "gpt-4o",
)
self._streaming = streaming
self._session_config = session_config or {}
self._current_session = None
self._accumulated_text = ""
# Set model name for metrics
if model_config and "model" in model_config:
self.set_model_name(model_config["model"])
else:
self.set_model_name("gpt-4o") # Default model
logger.info(f"Initialized OpenAI Agent service: {self._agent.name}")
@property
def agent(self) -> Agent:
"""Get the underlying OpenAI Agent.
Returns:
The configured Agent instance.
"""
return self._agent
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIAgentContextAggregatorPair:
"""Create OpenAI-specific context aggregators for agent interactions.
Creates a pair of context aggregators optimized for OpenAI Agent interactions,
including support for function calls, tool usage, and conversation management.
Args:
context: The LLM context to create aggregators for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
OpenAIAgentContextAggregatorPair: A pair of context aggregators, one for
the user and one for the assistant, encapsulated in an
OpenAIAgentContextAggregatorPair.
"""
user = OpenAIAgentUserContextAggregator(context, params=user_params)
assistant = OpenAIAgentAssistantContextAggregator(context, params=assistant_params)
return OpenAIAgentContextAggregatorPair(_user=user, _assistant=assistant)
def update_agent_config(
self,
*,
instructions: Optional[str] = None,
model_config: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
"""Update agent configuration dynamically.
Args:
instructions: New system instructions for the agent.
model_config: Updated model configuration.
**kwargs: Additional agent configuration parameters.
"""
if instructions:
self._agent.instructions = instructions
logger.info(f"Updated agent instructions for {self._agent.name}")
if model_config:
# Note: OpenAI Agents SDK handles model configuration during agent creation
# We can't update model_config after agent is created, but we can update our model name
if "model" in model_config:
self.set_model_name(model_config["model"])
logger.info(f"Updated model config for {self._agent.name}")
async def start(self, frame: StartFrame):
"""Start the OpenAI Agent service.
Initializes the agent session and prepares for processing.
Args:
frame: The start frame containing initialization parameters.
"""
logger.info(f"Starting OpenAI Agent service: {self._agent.name}")
await super().start(frame)
async def stop(self, frame: EndFrame):
"""Stop the OpenAI Agent service.
Cleans up resources and ends the current session.
Args:
frame: The end frame.
"""
logger.info(f"Stopping OpenAI Agent service: {self._agent.name}")
await super().stop(frame)
async def cancel(self, frame: CancelFrame):
"""Cancel the OpenAI Agent service.
Cancels any ongoing operations.
Args:
frame: The cancel frame.
"""
logger.info(f"Cancelling OpenAI Agent service: {self._agent.name}")
await super().cancel(frame)
@override
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process frames and handle agent interactions.
Processes OpenAILLMContextFrame and TextFrame by running them through the OpenAI Agent
and streams the results back as LLM frames.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
# Process context frame through the agent
try:
await self.push_frame(LLMFullResponseStartFrame())
# Extract the latest user message from the context
messages = frame.context.get_messages()
if messages:
# Get the last user message
for message in reversed(messages):
if message.get("role") == "user":
content = message.get("content", "")
if isinstance(content, list):
# Extract text from content array
text_parts = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
text_parts.append(part.get("text", ""))
user_input = " ".join(text_parts)
else:
user_input = str(content)
if user_input.strip():
await self._process_agent_request(user_input)
break
await self.push_frame(LLMFullResponseEndFrame())
except Exception as e:
logger.error(f"Error processing agent context: {e}")
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
elif isinstance(frame, TextFrame):
# Process text input through the agent directly (for backwards compatibility)
try:
await self.push_frame(LLMFullResponseStartFrame())
await self._process_agent_request(frame.text)
await self.push_frame(LLMFullResponseEndFrame())
except Exception as e:
logger.error(f"Error processing agent request: {e}")
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
else:
# For frames we don't handle, pass them through with direction
await self.push_frame(frame, direction)
async def _process_agent_request(self, input_text: str):
"""Process an agent request and stream the results.
Args:
input_text: The user input text to process.
"""
logger.debug(f"Processing agent request: {input_text}")
if self._streaming:
await self._process_streaming_response(input_text)
else:
await self._process_non_streaming_response(input_text)
async def _process_streaming_response(self, input_text: str):
"""Process a streaming agent response.
Args:
input_text: The user input text to process.
"""
try:
# Run the agent with streaming
result: RunResultStreaming = Runner.run_streamed(
self._agent, input_text, context=self._session_config
)
has_streaming_deltas = False
# Process the stream events
async for event in result.stream_events():
if event.type == "raw_response_event":
# Handle token-by-token streaming
# Only check for delta on events that are known to have it
if hasattr(event.data, "delta") and getattr(event.data, "delta", None):
delta_text = getattr(event.data, "delta", "")
if delta_text:
has_streaming_deltas = True
self._accumulated_text += delta_text
await self.push_frame(LLMTextFrame(text=delta_text))
elif event.type == "run_item_stream_event":
# Handle completed items
if event.item.type == "message_output_item":
# Only process complete message if we didn't get streaming deltas
if not has_streaming_deltas:
message_text = self._extract_message_text(event.item)
logger.debug(
f"Processing complete message (no deltas): {message_text[:50]}..."
if len(message_text) > 50
else f"Processing complete message: {message_text}"
)
if message_text:
await self.push_frame(LLMTextFrame(text=message_text))
elif event.item.type == "tool_call_item":
# Use getattr for safe attribute access
tool_name = getattr(event.item, "tool_name", "unknown")
logger.debug(f"Tool called: {tool_name}")
elif event.item.type == "tool_call_output_item":
output = getattr(event.item, "output", "no output")
logger.debug(f"Tool output: {output}")
elif event.type == "agent_updated_stream_event":
logger.debug(f"Agent updated: {event.new_agent.name}")
# Reset accumulated text for next request
self._accumulated_text = ""
except Exception as e:
logger.error(f"Error in streaming response: {e}")
raise
async def _process_non_streaming_response(self, input_text: str):
"""Process a non-streaming agent response.
Args:
input_text: The user input text to process.
"""
try:
# Run the agent without streaming
result: RunResult = await Runner.run(
self._agent, input_text, context=self._session_config
)
# Send the final output
if result.final_output:
await self.push_frame(LLMTextFrame(text=result.final_output))
except Exception as e:
logger.error(f"Error in non-streaming response: {e}")
raise
def _extract_message_text(self, item) -> str:
"""Extract text from a message output item.
Args:
item: The message output item from the agent.
Returns:
The extracted text content.
"""
try:
# Handle OpenAI Agents SDK MessageOutputItem format
if hasattr(item, "raw_item") and hasattr(item.raw_item, "content"):
content = item.raw_item.content
if isinstance(content, list):
text_parts = []
for content_part in content:
if hasattr(content_part, "text"):
text_parts.append(content_part.text)
elif (
isinstance(content_part, dict)
and content_part.get("type") == "output_text"
):
text_parts.append(content_part.get("text", ""))
elif isinstance(content_part, dict) and content_part.get("type") == "text":
text_parts.append(content_part.get("text", ""))
return "".join(text_parts)
elif isinstance(content, str):
return content
# Handle direct content attribute
elif hasattr(item, "content"):
if isinstance(item.content, str):
return item.content
elif isinstance(item.content, list):
# Extract text from content array
text_parts = []
for content_part in item.content:
if isinstance(content_part, dict) and content_part.get("type") == "text":
text_parts.append(content_part.get("text", ""))
elif isinstance(content_part, str):
text_parts.append(content_part)
return "".join(text_parts)
# If no text content found, return empty string instead of str(item)
logger.debug(f"No extractable text content found in item: {type(item)}")
return ""
except Exception as e:
logger.warning(f"Could not extract text from message item: {e}")
return ""
async def add_tool(self, tool_function: ToolLike):
"""Add a tool function to the agent.
Args:
tool_function: A callable function or Tool object to add as a tool.
"""
if hasattr(self._agent, "tools"):
# Cast to Any to handle the type variance issue
tools_list: List[Any] = self._agent.tools
tools_list.append(tool_function)
tool_name = getattr(
tool_function, "__name__", getattr(tool_function, "name", "unknown")
)
logger.info(f"Added tool {tool_name} to agent {self._agent.name}")
async def add_handoff_agent(self, agent: AgentLike):
"""Add a handoff agent.
Args:
agent: Another Agent instance or handoff object that this agent can hand off to.
"""
if hasattr(self._agent, "handoffs"):
# Cast to Any to handle the type variance issue
handoffs_list: List[Any] = self._agent.handoffs
handoffs_list.append(agent)
agent_name = getattr(agent, "name", "unknown")
logger.info(f"Added handoff agent {agent_name} to agent {self._agent.name}")
def get_session_context(self) -> Dict[str, Any]:
"""Get the current session context.
Returns:
Dictionary containing the current session context.
"""
return self._session_config.copy()
def update_session_context(self, context: Dict[str, Any]):
"""Update the session context.
Args:
context: Dictionary of context updates to apply.
"""
self._session_config.update(context)
logger.debug(f"Updated session context for agent {self._agent.name}")
class OpenAIAgentUserContextAggregator(LLMUserContextAggregator):
"""OpenAI Agent-specific user context aggregator.
Handles aggregation of user messages for OpenAI Agent services.
Inherits all functionality from the base LLMUserContextAggregator.
"""
pass
class OpenAIAgentAssistantContextAggregator(LLMAssistantContextAggregator):
"""OpenAI Agent-specific assistant context aggregator.
Handles aggregation of assistant messages for OpenAI Agent services,
with specialized support for OpenAI's function calling format,
tool usage tracking, and agent interaction management.
"""
pass

View File

@@ -478,7 +478,11 @@ class SmallWebRTCClient:
self._screen_video_track = None
self._audio_output_track = None
self._video_output_track = None
await self._callbacks.on_client_disconnected(self._webrtc_connection)
# Trigger `on_client_disconnected` if the client actually disconnects,
# that is, we are not the ones disconnecting.
if not self._closing:
await self._callbacks.on_client_disconnected(self._webrtc_connection)
async def _handle_app_message(self, message: Any):
"""Handle incoming application messages."""

View File

@@ -138,7 +138,6 @@ class FastAPIWebsocketClient:
):
logger.warning("Closing already disconnected websocket!")
self._closing = True
await self.trigger_client_disconnected()
async def disconnect(self):
"""Disconnect the WebSocket client."""
@@ -152,8 +151,6 @@ class FastAPIWebsocketClient:
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception while closing the websocket: {e}")
finally:
await self.trigger_client_disconnected()
async def trigger_client_disconnected(self):
"""Trigger the client disconnected callback."""
@@ -298,7 +295,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
await self._client.trigger_client_disconnected()
# Trigger `on_client_disconnected` if the client actually disconnects,
# that is, we are not the ones disconnecting.
if not self._client.is_closing:
await self._client.trigger_client_disconnected()
async def _monitor_websocket(self):
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
@@ -446,6 +446,9 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
async def _write_frame(self, frame: Frame):
"""Serialize and send a frame through the WebSocket."""
if self._client.is_closing or not self._client.is_connected:
return
if not self._params.serializer:
return

172
test_openai_agent.py Normal file
View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""Simple test script for OpenAI Agent service."""
import asyncio
import os
from unittest.mock import MagicMock, patch
# Mock the OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key-for-testing"
from pipecat.frames.frames import TextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai_agent import OpenAIAgentService
async def test_basic_functionality():
"""Test basic OpenAI Agent service functionality."""
print("🧪 Testing OpenAI Agent Service...")
# Create a simple weather tool for testing
def get_weather(location: str) -> str:
"""Get weather for a location."""
return f"The weather in {location} is sunny and 22°C."
try:
# Create the service
print("📋 Creating OpenAI Agent service...")
service = OpenAIAgentService(
name="Test Assistant",
instructions="You are a helpful test assistant.",
tools=[get_weather],
api_key="test-key",
streaming=True,
)
print(f"✅ Service created successfully!")
print(f" - Agent name: {service.agent.name}")
print(f" - Model name: {service.model_name}")
print(f" - Streaming enabled: {service._streaming}")
# Test basic configuration
print("⚙️ Testing configuration updates...")
service.update_agent_config(
instructions="Updated test instructions",
model_config={"model": "gpt-4o", "temperature": 0.5},
)
print(f"✅ Configuration updated!")
print(f" - New instructions: {service.agent.instructions}")
print(f" - New model: {service.model_name}")
# Test session context
print("💾 Testing session context...")
service.update_session_context({"user_id": "test-user", "session": "test-session"})
context = service.get_session_context()
print(f"✅ Session context managed!")
print(f" - Context keys: {list(context.keys())}")
# Test adding tools
print("🔧 Testing tool management...")
def get_time() -> str:
"""Get current time."""
return "The current time is 3:00 PM."
await service.add_tool(get_time)
print(f"✅ Tool added successfully!")
print("\n🎉 All basic functionality tests passed!")
return True
except Exception as e:
print(f"❌ Test failed with error: {e}")
return False
async def test_frame_processing():
"""Test frame processing with mocked responses."""
print("\n🔄 Testing frame processing...")
try:
# Mock the Runner to avoid actual API calls
with patch("pipecat.services.openai_agent.agent_service.Runner") as mock_runner:
# Set up mock responses
mock_stream_result = MagicMock()
# Mock stream events
async def mock_stream_events():
# Simulate streaming response
yield MagicMock(type="raw_response_event", data=MagicMock(delta="Hello "))
yield MagicMock(type="raw_response_event", data=MagicMock(delta="from "))
yield MagicMock(type="raw_response_event", data=MagicMock(delta="agent!"))
# Simulate completed message
mock_item = MagicMock()
mock_item.type = "message_output_item"
mock_item.content = "Hello from agent!"
yield MagicMock(type="run_item_stream_event", item=mock_item)
mock_stream_result.stream_events.return_value = mock_stream_events()
mock_runner.run_streamed.return_value = mock_stream_result
# Create service with mocked runner
service = OpenAIAgentService(
name="Test Assistant",
instructions="You are a helpful test assistant.",
api_key="test-key",
streaming=True,
)
# Collect output frames
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
print(f" 📤 Frame: {type(frame).__name__}")
if hasattr(frame, "text"):
print(f" Text: '{frame.text}'")
service.push_frame = mock_push_frame
# Process a text frame
print("📝 Processing text frame...")
text_frame = TextFrame("Hello, how are you?")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait for async processing
await asyncio.sleep(0.2)
print(f"✅ Frame processing completed!")
print(f" - Generated {len(output_frames)} output frames")
# Check if we got expected frame types
frame_types = [type(frame).__name__ for frame in output_frames]
print(f" - Frame types: {frame_types}")
return True
except Exception as e:
print(f"❌ Frame processing test failed: {e}")
return False
async def main():
"""Run all tests."""
print("🚀 Starting OpenAI Agent Service Tests\n")
try:
# Run basic functionality tests
basic_test = await test_basic_functionality()
# Run frame processing tests
frame_test = await test_frame_processing()
# Summary
print(f"\n📊 Test Results:")
print(f" - Basic functionality: {'✅ PASS' if basic_test else '❌ FAIL'}")
print(f" - Frame processing: {'✅ PASS' if frame_test else '❌ FAIL'}")
if basic_test and frame_test:
print(f"\n🎉 All tests passed! The OpenAI Agent service is working correctly.")
else:
print(f"\n⚠️ Some tests failed. Please check the output above.")
except Exception as e:
print(f"❌ Test suite failed with error: {e}")
if __name__ == "__main__":
asyncio.run(main())

33
test_simple_agent.py Normal file
View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
import asyncio
import os
from loguru import logger
# Test the actual agents package API
try:
from agents import Agent, run
# Create a simple agent
agent = Agent(
name="test-agent",
instructions="You are a helpful assistant.",
)
print("✅ Agent created successfully!")
print(f"Agent name: {agent.name}")
# Test a simple conversation
async def test_agent():
result = await run(agent, "Hello, how are you?")
print(f"Agent response: {result}")
# Run the test
asyncio.run(test_agent())
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()

View File

@@ -0,0 +1,998 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""
Unit tests for LLM adapters' get_llm_invocation_params() method.
These tests focus specifically on the "messages" field generation for different adapters, ensuring:
For OpenAI adapter:
1. LLMStandardMessage objects are passed through unchanged
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
3. Complex message structures (like multi-part content) are preserved
4. System instructions are preserved throughout messages at any position
For Gemini adapter:
1. LLMStandardMessage objects are converted to Gemini Content format
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
4. System messages are extracted as system_instruction (without duplication)
5. Single system instruction is converted to user message when no other messages exist
6. Multiple system instructions: first extracted, later ones converted to user messages
For Anthropic adapter:
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
For AWS Bedrock adapter:
1. LLMStandardMessage objects are converted to AWS Bedrock format
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
"""
import unittest
from google.genai.types import Content, Part
from openai.types.chat import ChatCompletionMessage
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMSpecificMessage,
LLMStandardMessage,
)
class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = OpenAILLMAdapter()
def test_standard_messages_passed_through_unchanged(self):
"""Test that LLMStandardMessage objects are passed through unchanged to OpenAI params."""
# Create standard messages (OpenAI format)
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
]
# Create context with these messages
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify messages are passed through unchanged
self.assertEqual(params["messages"], standard_messages)
self.assertEqual(len(params["messages"]), 3)
# Verify content matches exactly
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
def test_llm_specific_message_filtering(self):
"""Test that OpenAI-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Gemini specific message"}
),
{"role": "user", "content": "Standard user message"},
self.adapter.create_llm_specific_message(
{"role": "assistant", "content": "OpenAI specific response"}
),
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only include standard messages and OpenAI-specific ones
# (3 total: system, standard user, openai assistant)
self.assertEqual(len(params["messages"]), 3)
# Verify the correct messages are included
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][1]["content"], "Standard user message")
self.assertEqual(
params["messages"][2], {"role": "assistant", "content": "OpenAI specific response"}
)
def test_complex_message_content_preserved(self):
"""Test that complex message content (like multi-part messages) is preserved."""
# Create a message with complex content structure (text + image)
complex_image_message = {
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..."},
},
],
}
# Create a message with multiple text blocks
multi_text_message = {
"role": "assistant",
"content": [
{"type": "text", "text": "Let me analyze this step by step:"},
{"type": "text", "text": "1. First, I'll examine the visual elements"},
{"type": "text", "text": "2. Then I'll provide my conclusions"},
],
}
messages = [
{"role": "system", "content": "You are a helpful assistant that can analyze images."},
complex_image_message,
multi_text_message,
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify complex content is preserved
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][1], complex_image_message)
self.assertEqual(params["messages"][2], multi_text_message)
# Verify the image message structure is maintained
image_content = params["messages"][1]["content"]
self.assertIsInstance(image_content, list)
self.assertEqual(len(image_content), 2)
self.assertEqual(image_content[0]["type"], "text")
self.assertEqual(image_content[1]["type"], "image_url")
# Verify the multi-text message structure is maintained
text_content = params["messages"][2]["content"]
self.assertIsInstance(text_content, list)
self.assertEqual(len(text_content), 3)
for i, text_block in enumerate(text_content):
self.assertEqual(text_block["type"], "text")
self.assertEqual(text_content[0]["text"], "Let me analyze this step by step:")
self.assertEqual(text_content[1]["text"], "1. First, I'll examine the visual elements")
self.assertEqual(text_content[2]["text"], "2. Then I'll provide my conclusions")
def test_system_instructions_preserved_throughout_messages(self):
"""Test that OpenAI adapter preserves system instructions sprinkled throughout messages."""
# Create messages with system instructions at different positions
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."},
{"role": "user", "content": "Tell me about Python."},
{"role": "system", "content": "Use simple language."},
{"role": "assistant", "content": "Python is a programming language."},
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# OpenAI should preserve all messages unchanged, including multiple system messages
self.assertEqual(len(params["messages"]), 7)
# Verify system messages are preserved at their original positions
self.assertEqual(params["messages"][0]["role"], "system")
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
self.assertEqual(params["messages"][3]["role"], "system")
self.assertEqual(params["messages"][3]["content"], "Remember to be concise.")
self.assertEqual(params["messages"][5]["role"], "system")
self.assertEqual(params["messages"][5]["content"], "Use simple language.")
# Verify other messages remain unchanged
self.assertEqual(params["messages"][1]["role"], "user")
self.assertEqual(params["messages"][2]["role"], "assistant")
self.assertEqual(params["messages"][4]["role"], "user")
self.assertEqual(params["messages"][6]["role"], "assistant")
class TestGeminiGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = GeminiLLMAdapter()
def test_standard_messages_converted_to_gemini_format(self):
"""Test that LLMStandardMessage objects are converted to Gemini Content format."""
# Create standard messages (OpenAI format)
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
]
# Create context with these messages
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction is extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Verify messages are converted to Gemini format (2 messages: user + model)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user)
user_msg = params["messages"][0]
self.assertIsInstance(user_msg, Content)
self.assertEqual(user_msg.role, "user")
self.assertEqual(len(user_msg.parts), 1)
self.assertEqual(user_msg.parts[0].text, "Hello, how are you?")
# Check second message (assistant -> model)
model_msg = params["messages"][1]
self.assertIsInstance(model_msg, Content)
self.assertEqual(model_msg.role, "model")
self.assertEqual(len(model_msg.parts), 1)
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
def test_llm_specific_message_filtering(self):
"""Test that Gemini-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific message"}
),
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
{"role": "user", "content": "Standard user message"},
self.adapter.create_llm_specific_message(
Content(role="model", parts=[Part(text="Gemini specific response")]),
),
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only include standard messages and Gemini-specific ones
# (2 total: converted standard user + gemini model)
self.assertEqual(len(params["messages"]), 2)
# Verify system instruction
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Verify the correct messages are included
self.assertEqual(params["messages"][0].role, "user")
self.assertEqual(params["messages"][0].parts[0].text, "Standard user message")
self.assertEqual(params["messages"][1].role, "model")
self.assertEqual(params["messages"][1].parts[0].text, "Gemini specific response")
def test_complex_message_content_preserved(self):
"""Test that complex message content (like multi-part messages) is preserved and converted.
This test covers image, audio, and multi-text content conversion to Gemini format.
"""
# Create a message with complex content structure (text + image)
# Using a minimal valid base64 image data
complex_image_message = {
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
},
},
],
}
# Create a message with multiple text blocks
multi_text_message = {
"role": "assistant",
"content": [
{"type": "text", "text": "Let me analyze this step by step:"},
{"type": "text", "text": "1. First, I'll examine the visual elements"},
{"type": "text", "text": "2. Then I'll provide my conclusions"},
],
}
# Create a message with audio input (text + audio)
# Using a minimal valid base64 audio data (16 bytes of WAV header)
audio_message = {
"role": "user",
"content": [
{"type": "text", "text": "Can you transcribe this audio?"},
{
"type": "input_audio",
"input_audio": {
"data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=",
"format": "wav",
},
},
],
}
messages = [
{
"role": "system",
"content": "You are a helpful assistant that can analyze images and audio.",
},
complex_image_message,
multi_text_message,
audio_message,
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction
self.assertEqual(
params["system_instruction"],
"You are a helpful assistant that can analyze images and audio.",
)
# Verify complex content is converted to Gemini format
# Note: Gemini adapter may add system instruction back as user message in some cases
self.assertGreaterEqual(len(params["messages"]), 3)
# Find the different message types
user_with_image = None
model_with_text = None
user_with_audio = None
for msg in params["messages"]:
if msg.role == "user" and len(msg.parts) == 2:
# Check if it's image or audio based on the text content
if hasattr(msg.parts[0], "text") and "image" in msg.parts[0].text:
user_with_image = msg
elif hasattr(msg.parts[0], "text") and "audio" in msg.parts[0].text:
user_with_audio = msg
elif msg.role == "model" and len(msg.parts) == 3:
model_with_text = msg
# Verify the image message structure is converted properly
self.assertIsNotNone(user_with_image, "Should have user message with image")
self.assertEqual(len(user_with_image.parts), 2)
# First part should be text
self.assertEqual(user_with_image.parts[0].text, "What's in this image?")
# Second part should be image data (converted to Blob)
self.assertIsNotNone(user_with_image.parts[1].inline_data)
self.assertEqual(user_with_image.parts[1].inline_data.mime_type, "image/jpeg")
# Verify the audio message structure is converted properly
self.assertIsNotNone(user_with_audio, "Should have user message with audio")
self.assertEqual(len(user_with_audio.parts), 2)
# First part should be text
self.assertEqual(user_with_audio.parts[0].text, "Can you transcribe this audio?")
# Second part should be audio data (converted to Blob)
self.assertIsNotNone(user_with_audio.parts[1].inline_data)
self.assertEqual(user_with_audio.parts[1].inline_data.mime_type, "audio/wav")
# Verify the multi-text message structure is converted properly
self.assertIsNotNone(model_with_text, "Should have model message with multi-text")
self.assertEqual(len(model_with_text.parts), 3)
# All parts should be text
expected_texts = [
"Let me analyze this step by step:",
"1. First, I'll examine the visual elements",
"2. Then I'll provide my conclusions",
]
for i, expected_text in enumerate(expected_texts):
self.assertEqual(model_with_text.parts[i].text, expected_text)
def test_single_system_instruction_converted_to_user(self):
"""Test that when there's only a system instruction, it gets converted to user message."""
# Create context with only a system message
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
context = LLMContext(messages=messages)
params = self.adapter.get_llm_invocation_params(context)
# System instruction should be extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# But since there are no other messages, it should also be added back as a user message
self.assertEqual(len(params["messages"]), 1)
self.assertEqual(params["messages"][0].role, "user")
self.assertEqual(params["messages"][0].parts[0].text, "You are a helpful assistant.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
# Create messages with multiple system instructions
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."},
{"role": "user", "content": "Tell me about Python."},
{"role": "system", "content": "Use simple language."},
{"role": "assistant", "content": "Python is a programming language."},
]
context = LLMContext(messages=messages)
params = self.adapter.get_llm_invocation_params(context)
# First system instruction should be extracted
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Should have 6 messages (original 7 minus 1 system instruction that was extracted)
self.assertEqual(len(params["messages"]), 6)
# Find the converted system messages (should be user role now)
converted_system_messages = []
for msg in params["messages"]:
if msg.role == "user" and (
msg.parts[0].text == "Remember to be concise."
or msg.parts[0].text == "Use simple language."
):
converted_system_messages.append(msg.parts[0].text)
# Should have 2 converted system messages
self.assertEqual(len(converted_system_messages), 2)
self.assertIn("Remember to be concise.", converted_system_messages)
self.assertIn("Use simple language.", converted_system_messages)
# Verify that regular user and assistant messages are preserved
user_messages = [msg for msg in params["messages"] if msg.role == "user"]
model_messages = [msg for msg in params["messages"] if msg.role == "model"]
# Should have 4 user messages: 2 original + 2 converted from system
self.assertEqual(len(user_messages), 4)
# Should have 2 model messages (converted from assistant)
self.assertEqual(len(model_messages), 2)
class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = AnthropicLLMAdapter()
def test_standard_messages_converted_to_anthropic_format(self):
"""Test that LLMStandardMessage objects are converted to Anthropic MessageParam format."""
# Create standard messages
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
]
# Create context
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Verify system instruction is extracted
self.assertEqual(params["system"], "You are a helpful assistant.")
# Verify messages are in the params (2 messages after system extraction)
self.assertIn("messages", params)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertEqual(user_msg["content"], "Hello, how are you?")
# Check second message (assistant)
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
def test_llm_specific_message_filtering(self):
"""Test that Anthropic-specific messages are included and others are filtered out."""
# Create anthropic-specific message content
anthropic_message_content = {
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "fake_data"},
},
],
}
messages = [
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(anthropic_message_content),
{"role": "assistant", "content": "Response"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + anthropic-specific merged)
self.assertEqual(len(params["messages"]), 2)
# First message: merged user message (standard + anthropic-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + anthropic text + anthropic image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["type"], "text")
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertEqual(user_msg["content"][2]["type"], "image")
# Second message: standard response
self.assertEqual(params["messages"][1]["content"], "Response")
def test_consecutive_same_role_messages_merged(self):
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
messages = [
{"role": "user", "content": "First user message"},
{"role": "user", "content": "Second user message"},
{"role": "user", "content": "Third user message"},
{"role": "assistant", "content": "First assistant message"},
{"role": "assistant", "content": "Second assistant message"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Should have 2 messages after merging (1 user, 1 assistant)
self.assertEqual(len(params["messages"]), 2)
# Check merged user message
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["type"], "text")
self.assertEqual(user_msg["content"][0]["text"], "First user message")
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
self.assertEqual(user_msg["content"][2]["type"], "text")
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
# Check merged assistant message
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 2)
self.assertEqual(assistant_msg["content"][0]["type"], "text")
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
self.assertEqual(assistant_msg["content"][1]["type"], "text")
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
def test_empty_text_converted_to_empty_placeholder(self):
"""Test that empty text content is converted to "(empty)" string."""
messages = [
{"role": "user", "content": ""}, # Empty string
{
"role": "assistant",
"content": [
{"type": "text", "text": ""}, # Empty text in list content
{"type": "text", "text": "Valid text"},
],
},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Check that empty string content was converted
user_msg = params["messages"][0]
self.assertEqual(user_msg["content"], "(empty)")
# Check that empty text in list content was converted
assistant_msg = params["messages"][1]
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
def test_complex_message_content_preserved(self):
"""Test that complex message structures (text + image) are properly converted to Anthropic format."""
# Create a complex message with both text and image content
complex_message = {
"role": "user",
"content": [
{"type": "text", "text": "What do you see in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,fake_image_data"},
},
{"type": "text", "text": "Please describe it in detail."},
],
}
messages = [
complex_message,
{"role": "assistant", "content": "I can see the image clearly."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Verify complex message structure is preserved and converted
self.assertEqual(len(params["messages"]), 2)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
# Note: Anthropic adapter reorders single images to come before text, as per Anthropic docs
# Check image part (should be moved to first position and converted from image_url to image)
self.assertEqual(user_msg["content"][0]["type"], "image")
self.assertIn("source", user_msg["content"][0])
self.assertEqual(user_msg["content"][0]["source"]["type"], "base64")
self.assertEqual(user_msg["content"][0]["source"]["media_type"], "image/jpeg")
self.assertEqual(user_msg["content"][0]["source"]["data"], "fake_image_data")
# Check first text part (moved to second position)
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
# Check second text part (moved to third position)
self.assertEqual(user_msg["content"][2]["type"], "text")
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."}, # Later system message
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# System instruction should be extracted from first message
self.assertEqual(params["system"], "You are a helpful assistant.")
# Should have 3 messages remaining (system message was removed, later system converted to user)
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"], "Hello")
self.assertEqual(params["messages"][1]["role"], "assistant")
self.assertEqual(params["messages"][1]["content"], "Hi there!")
# Later system message should be converted to user role
self.assertEqual(params["messages"][2]["role"], "user")
self.assertEqual(params["messages"][2]["content"], "Remember to be concise.")
def test_single_system_message_converted_to_user(self):
"""Test that a single system message is converted to user role when no other messages exist."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# System should be NOT_GIVEN since we only have one message
from anthropic import NOT_GIVEN
self.assertEqual(params["system"], NOT_GIVEN)
# Single system message should be converted to user role
self.assertEqual(len(params["messages"]), 1)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
def setUp(self) -> None:
"""Sets up a common adapter instance for all tests."""
self.adapter = AWSBedrockLLMAdapter()
def test_standard_messages_converted_to_aws_bedrock_format(self):
"""Test that LLMStandardMessage objects are converted to AWS Bedrock format."""
# Create standard messages
standard_messages: list[LLMStandardMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
]
# Create context
context = LLMContext(messages=standard_messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction is extracted (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# Verify messages are in the params (2 messages after system extraction)
self.assertIn("messages", params)
self.assertEqual(len(params["messages"]), 2)
# Check first message (user) - should be converted to AWS Bedrock format
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 1)
self.assertEqual(user_msg["content"][0]["text"], "Hello, how are you?")
# Check second message (assistant) - should be converted to AWS Bedrock format
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 1)
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
def test_llm_specific_message_filtering(self):
"""Test that AWS-specific messages are included and others are filtered out."""
# Create aws-specific message content (which is what AWS Bedrock uses)
aws_message_content = {
"role": "user",
"content": [
{"text": "Hello"},
{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}},
],
}
messages = [
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(message=aws_message_content),
{"role": "assistant", "content": "Response"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + aws-specific merged)
self.assertEqual(len(params["messages"]), 2)
# First message: merged user message (standard + aws-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + aws text + aws image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertIn("image", user_msg["content"][2])
# Second message: standard response
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
def test_consecutive_same_role_messages_merged(self):
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
messages = [
{"role": "user", "content": "First user message"},
{"role": "user", "content": "Second user message"},
{"role": "user", "content": "Third user message"},
{"role": "assistant", "content": "First assistant message"},
{"role": "assistant", "content": "Second assistant message"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should have 2 messages after merging (1 user, 1 assistant)
self.assertEqual(len(params["messages"]), 2)
# Check merged user message
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["text"], "First user message")
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
# Check merged assistant message
assistant_msg = params["messages"][1]
self.assertEqual(assistant_msg["role"], "assistant")
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(len(assistant_msg["content"]), 2)
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
def test_empty_text_converted_to_empty_placeholder(self):
"""Test that empty text content is converted to "(empty)" string."""
messages = [
{"role": "user", "content": ""}, # Empty string
{
"role": "assistant",
"content": [
{"type": "text", "text": ""}, # Empty text in list content
{"type": "text", "text": "Valid text"},
],
},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Check that empty string content was converted
user_msg = params["messages"][0]
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(user_msg["content"][0]["text"], "(empty)")
# Check that empty text in list content was converted
assistant_msg = params["messages"][1]
self.assertIsInstance(assistant_msg["content"], list)
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
def test_complex_message_content_preserved(self):
"""Test that complex message structures (text + image) are properly converted to AWS Bedrock format."""
# Create a complex message with both text and image content
# Use a valid base64 string for the image
complex_message = {
"role": "user",
"content": [
{"type": "text", "text": "What do you see in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
},
},
{"type": "text", "text": "Please describe it in detail."},
],
}
messages = [
complex_message,
{"role": "assistant", "content": "I can see the image clearly."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify complex message structure is preserved and converted
self.assertEqual(len(params["messages"]), 2)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
self.assertEqual(len(user_msg["content"]), 3)
# Note: AWS Bedrock adapter reorders single images to come before text, like Anthropic
# Check image part (should be moved to first position and converted from image_url to image)
self.assertIn("image", user_msg["content"][0])
self.assertEqual(user_msg["content"][0]["image"]["format"], "jpeg")
self.assertIn("source", user_msg["content"][0]["image"])
self.assertIn("bytes", user_msg["content"][0]["image"]["source"])
# Check first text part (moved to second position)
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
# Check second text part (moved to third position)
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
def test_multiple_system_instructions_handling(self):
"""Test that first system instruction is extracted, later ones converted to user messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "system", "content": "Remember to be concise."}, # Later system message
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# System instruction should be extracted from first message (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# Should have 3 messages remaining (system message was removed, later system converted to user)
self.assertEqual(len(params["messages"]), 3)
self.assertEqual(params["messages"][0]["role"], "user")
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
self.assertEqual(params["messages"][1]["role"], "assistant")
self.assertEqual(params["messages"][1]["content"][0]["text"], "Hi there!")
# Later system message should be converted to user role
self.assertEqual(params["messages"][2]["role"], "user")
self.assertEqual(params["messages"][2]["content"][0]["text"], "Remember to be concise.")
def test_single_system_message_handling(self):
"""Test that a single system message is extracted as system parameter and no messages remain."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# System should be extracted (in AWS Bedrock format)
self.assertIsInstance(params["system"], list)
self.assertEqual(len(params["system"]), 1)
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
# No messages should remain after system extraction
self.assertEqual(len(params["messages"]), 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,286 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for OpenAI Agent service."""
import asyncio
import os
import sys
import unittest.mock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add src to path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
class MockAgent:
"""Mock Agent for testing."""
def __init__(self, name="Test Agent", instructions="Test instructions"):
self.name = name
self.instructions = instructions
self.tools = []
self.handoffs = []
class MockRunResult:
"""Mock RunResult for testing."""
def __init__(self, final_output="Test response"):
self.final_output = final_output
class MockStreamEvent:
"""Mock StreamEvent for testing."""
def __init__(self, event_type, data=None, item=None):
self.type = event_type
self.data = data
self.item = item
class MockMessageItem:
"""Mock message item for testing."""
def __init__(self, content="Test content"):
self.type = "message_output_item"
self.content = content
class MockRunner:
"""Mock Runner for testing."""
@staticmethod
async def run(agent, input_text, context=None):
return MockRunResult("Mocked response")
@staticmethod
def run_streamed(agent, input_text, context=None):
class MockStreamResult:
async def stream_events(self):
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="Test "))
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="response"))
yield MockStreamEvent(
"run_item_stream_event", item=MockMessageItem("Test response")
)
return MockStreamResult()
@pytest.fixture
def mock_openai_agents():
"""Mock the OpenAI Agents SDK imports."""
with patch.dict(
"sys.modules",
{
"agents": MagicMock(),
"agents.stream_events": MagicMock(),
"agents.result": MagicMock(),
},
):
# Mock the classes and functions we need
mock_agent = MagicMock()
mock_agent.return_value = MockAgent()
mock_runner = MagicMock()
mock_runner.run = AsyncMock(return_value=MockRunResult())
mock_runner.run_streamed = MagicMock(return_value=MockRunner.run_streamed(None, None))
with (
patch("pipecat.services.openai_agent.agent_service.Agent", mock_agent),
patch("pipecat.services.openai_agent.agent_service.Runner", mock_runner),
):
yield {
"Agent": mock_agent,
"Runner": mock_runner,
}
@pytest.mark.asyncio
async def test_openai_agent_service_init(mock_openai_agents):
"""Test OpenAI Agent service initialization."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
)
assert service.agent.name == "Test Agent"
assert service._streaming is True
@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_streaming(mock_openai_agents):
"""Test processing text frame with streaming enabled."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
)
# Mock the push_frame method to capture output
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
service.push_frame = mock_push_frame
# Process a text frame
text_frame = TextFrame("Hello, agent!")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait a bit for async processing
await asyncio.sleep(0.1)
# Check that appropriate frames were generated
assert len(output_frames) > 0
assert any(isinstance(frame, LLMFullResponseStartFrame) for frame in output_frames)
@pytest.mark.asyncio
async def test_openai_agent_service_process_text_frame_non_streaming(mock_openai_agents):
"""Test processing text frame with streaming disabled."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=False
)
# Mock the push_frame method to capture output
output_frames = []
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
output_frames.append(frame)
service.push_frame = mock_push_frame
# Process a text frame
text_frame = TextFrame("Hello, agent!")
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
# Wait a bit for async processing
await asyncio.sleep(0.1)
# Check that appropriate frames were generated
assert len(output_frames) > 0
@pytest.mark.asyncio
async def test_openai_agent_service_update_config(mock_openai_agents):
"""Test updating agent configuration."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Update configuration
service.update_agent_config(
instructions="Updated instructions", model_config={"model": "gpt-4o", "temperature": 0.7}
)
assert service.agent.instructions == "Updated instructions"
assert service.agent.model_config["model"] == "gpt-4o"
@pytest.mark.asyncio
async def test_openai_agent_service_session_context(mock_openai_agents):
"""Test session context management."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent",
instructions="Test instructions",
api_key="test-key",
session_config={"user_id": "test-user"},
)
# Get initial context
context = service.get_session_context()
assert context["user_id"] == "test-user"
# Update context
service.update_session_context({"session_id": "test-session"})
updated_context = service.get_session_context()
assert updated_context["user_id"] == "test-user"
assert updated_context["session_id"] == "test-session"
@pytest.mark.asyncio
async def test_openai_agent_service_add_tools(mock_openai_agents):
"""Test adding tools to the agent."""
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Define a test tool
def test_tool():
return "test result"
# Add the tool
await service.add_tool(test_tool)
# Check if tool was added (this depends on the mock implementation)
assert hasattr(service.agent, "tools")
@pytest.mark.asyncio
async def test_openai_agent_service_lifecycle(mock_openai_agents):
"""Test service lifecycle methods."""
from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
service = OpenAIAgentService(
name="Test Agent", instructions="Test instructions", api_key="test-key"
)
# Test start
start_frame = StartFrame()
await service.start(start_frame)
# Test cancel
cancel_frame = CancelFrame()
await service.cancel(cancel_frame)
# Test stop
end_frame = EndFrame()
await service.stop(end_frame)
def test_openai_agent_service_import_error():
"""Test that import error is handled gracefully."""
# Mock the import to fail
with patch.dict("sys.modules", {"agents": None}):
with pytest.raises(Exception) as exc_info:
# This should trigger the import error
import importlib
import pipecat.services.openai_agent.agent_service
importlib.reload(pipecat.services.openai_agent.agent_service)
assert "Missing module" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -196,10 +196,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
nonlocal start_received
start_received = True
@task.event_handler("on_pipeline_ended")
async def on_pipeline_ended(task, frame: EndFrame):
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame: Frame):
nonlocal end_received
end_received = True
end_received = isinstance(frame, EndFrame)
await task.queue_frame(EndFrame())
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
@@ -214,10 +214,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
@task.event_handler("on_pipeline_stopped")
async def on_pipeline_ended(task, frame: StopFrame):
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task, frame: Frame):
nonlocal stop_received
stop_received = True
stop_received = isinstance(frame, StopFrame)
await task.queue_frame(StopFrame())
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
@@ -441,10 +441,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
async def on_pipeline_started(task: PipelineTask, frame: StartFrame):
await task.cancel()
@task.event_handler("on_pipeline_cancelled")
async def on_pipeline_cancelled(task: PipelineTask, frame: CancelFrame):
@task.event_handler("on_pipeline_finished")
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
nonlocal cancelled
cancelled = True
cancelled = isinstance(frame, CancelFrame)
try:
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))

261
tests/test_run_inference.py Normal file
View File

@@ -0,0 +1,261 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from anthropic import NOT_GIVEN
from openai import NotGiven
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
@pytest.mark.asyncio
async def test_openai_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response."""
# Create service with mocked client
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello, world!"},
]
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello! How can I help you today?"
service._client.chat.completions.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.chat.completions.create.assert_called_once_with(
model="gpt-4",
messages=test_messages,
stream=False,
)
@pytest.mark.asyncio
async def test_openai_run_inference_client_exception():
"""Test that exceptions from the client are propagated."""
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.chat.completions.create.side_effect = Exception("API Error")
with pytest.raises(Exception, match="API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_anthropic_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Anthropic."""
# Create service with mocked client
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=test_messages, system=test_system, tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Hello! How can I help you today?"
service._client.messages.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(
mock_context, enable_prompt_caching=False
)
service._client.messages.create.assert_called_once_with(
model="claude-3-sonnet-20240229",
messages=test_messages,
system=test_system,
max_tokens=8192,
stream=False,
)
@pytest.mark.asyncio
async def test_anthropic_run_inference_client_exception():
"""Test that exceptions from the Anthropic client are propagated."""
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=[], system="Test system", tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.messages.create.side_effect = Exception("Anthropic API Error")
with pytest.raises(Exception, match="Anthropic API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_google_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Google."""
# Create service with mocked client
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=test_messages, system_instruction=test_system, tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content = MagicMock()
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_google_run_inference_client_exception():
"""Test that exceptions from the Google client are propagated."""
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=[], system_instruction="Test system", tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(
side_effect=Exception("Google API Error")
)
with pytest.raises(Exception, match="Google API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
# Create service and patch the session client method
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
test_system = [{"text": "You are a helpful assistant"}]
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=test_messages, system=test_system, tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock the client and response
mock_client = AsyncMock()
mock_response = {
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
}
mock_client.converse.return_value = mock_response
# Patch the _aws_session.client method to be an async context manager
async def mock_client_cm(*args, **kwargs):
return mock_client
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
mock_client.converse.assert_called_once()
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_client_exception():
"""Test that exceptions from the AWS Bedrock client are propagated."""
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock AWS client to raise exception
mock_client = AsyncMock()
mock_client.converse.side_effect = Exception("Bedrock API Error")
# Patch the _aws_session.client method to be an async context manager
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
with pytest.raises(Exception, match="Bedrock API Error"):
await service.run_inference(mock_context)

7851
uv.lock generated

File diff suppressed because it is too large Load Diff