Compare commits
25 Commits
v0.0.85
...
hush/openA
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29d4a56663 | ||
|
|
373a09ecd6 | ||
|
|
07f54c48f3 | ||
|
|
c8a3d65aa4 | ||
|
|
50a2a0dc86 | ||
|
|
0421d97954 | ||
|
|
54c8f336c3 | ||
|
|
b086fbafe6 | ||
|
|
cca90791c4 | ||
|
|
f2a5d408de | ||
|
|
044c6eba46 | ||
|
|
db71089f5e | ||
|
|
f861f5066f | ||
|
|
81cede2c60 | ||
|
|
7603203230 | ||
|
|
8569b61598 | ||
|
|
fe42187dc1 | ||
|
|
999e88c942 | ||
|
|
c04df2f28b | ||
|
|
100ef0ab5c | ||
|
|
42886d7105 | ||
|
|
22cbba002a | ||
|
|
c873798ce5 | ||
|
|
786387722a | ||
|
|
9f82c6b4a4 |
285
AGENTS.md
Normal file
285
AGENTS.md
Normal file
@@ -0,0 +1,285 @@
|
||||
# AGENTS.md
|
||||
|
||||
## Project Overview
|
||||
|
||||
Pipecat is an open-source Python framework for building real-time voice and multimodal conversational AI agents. The codebase is organized around a pipeline architecture where data flows through connected services (STT → LLM → TTS).
|
||||
|
||||
## Development Environment Setup
|
||||
|
||||
### Prerequisites
|
||||
- **Minimum Python Version:** 3.10
|
||||
- **Recommended Python Version:** 3.12
|
||||
- **Package Manager:** uv (recommended) or pip
|
||||
|
||||
### Setup Commands
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/pipecat-ai/pipecat.git
|
||||
cd pipecat
|
||||
|
||||
# Install dependencies with uv (recommended)
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra ultravox
|
||||
|
||||
# Or with pip
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# Install pre-commit hooks
|
||||
uv run pre-commit install
|
||||
|
||||
# Set up environment variables
|
||||
cp env.example .env
|
||||
```
|
||||
|
||||
## Build and Test Commands
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest tests/test_name.py
|
||||
|
||||
# Run tests with coverage
|
||||
uv run pytest --cov=pipecat --cov-report=html
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Format code (required before commits)
|
||||
uv run ruff format
|
||||
|
||||
# Lint code
|
||||
uv run ruff check
|
||||
|
||||
# Type checking
|
||||
uv run mypy src/pipecat
|
||||
|
||||
# Run pre-commit checks manually
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Documentation
|
||||
```bash
|
||||
# Build API documentation
|
||||
cd docs/api
|
||||
./build-docs.sh
|
||||
|
||||
# Build docs manually
|
||||
sphinx-build -b html . _build/html -W --keep-going
|
||||
```
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
### Python Standards
|
||||
- **Formatting:** Strict PEP 8 via Ruff
|
||||
- **Docstrings:** Google-style format
|
||||
- **Type Hints:** Required for all public APIs
|
||||
- **Import Organization:** Automated via Ruff
|
||||
|
||||
### Docstring Conventions
|
||||
- **Classes:** Describe purpose + `__init__` with complete `Args:` section
|
||||
- **Dataclasses:** Use `Parameters:` section, no `__init__` docstring
|
||||
- **Methods:** Include `Args:` and `Returns:` sections
|
||||
- **Properties:** Must have `Returns:` section
|
||||
- **Examples:** Use `Examples:` section with `::` syntax
|
||||
|
||||
### File Organization
|
||||
```
|
||||
src/pipecat/ # Main package
|
||||
├── processors/ # Frame processors
|
||||
├── services/ # AI service integrations
|
||||
├── transports/ # Communication layers
|
||||
├── frames/ # Data frame definitions
|
||||
└── pipeline/ # Pipeline orchestration
|
||||
|
||||
examples/foundational/ # Step-by-step tutorials
|
||||
tests/ # Test suite
|
||||
```
|
||||
|
||||
## Testing Instructions
|
||||
|
||||
### Test Structure
|
||||
- **Unit Tests:** Test individual components in isolation
|
||||
- **Integration Tests:** Test service interactions
|
||||
- **Example Tests:** Validate foundational examples work
|
||||
|
||||
### Adding Tests
|
||||
```bash
|
||||
# Test naming convention
|
||||
test_<component>_<functionality>.py
|
||||
|
||||
# Run specific test pattern
|
||||
uv run pytest -k "test_pipeline"
|
||||
|
||||
# Run with debugging
|
||||
uv run pytest -s -vv tests/test_name.py::test_function
|
||||
```
|
||||
|
||||
### Pre-commit Requirements
|
||||
All commits must pass:
|
||||
- Ruff formatting
|
||||
- Ruff linting
|
||||
- Type checking
|
||||
- Basic test suite
|
||||
|
||||
## Dependency Management
|
||||
|
||||
### Using uv (Recommended)
|
||||
```bash
|
||||
# Add runtime dependency
|
||||
uv add package-name
|
||||
|
||||
# Add optional dependency
|
||||
uv add --optional service package-name
|
||||
|
||||
# Add development dependency
|
||||
uv add --group dev package-name
|
||||
|
||||
# Update lockfile
|
||||
uv lock
|
||||
|
||||
# Sync dependencies
|
||||
uv sync
|
||||
```
|
||||
|
||||
### Important Notes
|
||||
- **Always commit both `pyproject.toml` and `uv.lock` together**
|
||||
- **Never manually edit `uv.lock`** - it's auto-generated
|
||||
- **Use extras for optional service dependencies** (e.g., `[openai]`, `[cartesia]`)
|
||||
|
||||
## Project Structure Guidelines
|
||||
|
||||
### Service Integration
|
||||
When adding new AI services:
|
||||
1. Create service class in `src/pipecat/services/<provider>/`
|
||||
2. Follow existing patterns (e.g., STTService, LLMService)
|
||||
3. Add to appropriate extras in `pyproject.toml`
|
||||
4. Include tests in `tests/`
|
||||
5. Add documentation examples
|
||||
|
||||
### Frame Processing
|
||||
For custom processors:
|
||||
1. Inherit from `FrameProcessor`
|
||||
2. Implement `process_frame()` method. ALWAYS explicitly call `await super().process_frame(frame, direction)` at the top of this method.
|
||||
3. Handle frame direction (FrameDirection.UPSTREAM/DOWNSTREAM)
|
||||
4. Add proper type hints and docstrings
|
||||
|
||||
### Transport Implementation
|
||||
For new transport layers:
|
||||
1. Inherit from `BaseTransport`
|
||||
2. Implement required abstract methods
|
||||
3. Handle connection lifecycle
|
||||
4. Support both input and output streams
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### API Keys
|
||||
- **Never commit API keys** to the repository
|
||||
- **Use environment variables** for all secrets
|
||||
- **Reference `env.example`** for required variables
|
||||
- **Use `.env` files** for local development
|
||||
|
||||
### Input Validation
|
||||
- **Validate all external inputs** (audio, text, API responses)
|
||||
- **Sanitize user data** before processing
|
||||
- **Handle rate limiting** for external services
|
||||
- **Implement proper timeout handling**
|
||||
|
||||
## Performance Guidelines
|
||||
|
||||
### Memory Management
|
||||
- **Clean up resources** in transport disconnection handlers
|
||||
- **Use async context managers** for service connections
|
||||
- **Implement proper frame lifecycle** management
|
||||
|
||||
### Latency Optimization
|
||||
- **Choose appropriate STT services** for latency requirements
|
||||
- **Use streaming TTS** when possible
|
||||
- **Implement connection pooling** for HTTP services
|
||||
- **Consider WebRTC** for real-time applications
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Error Handling
|
||||
```python
|
||||
@transport.event_handler("on_error")
|
||||
async def on_error(transport, error):
|
||||
logger.error(f"Transport error: {error}")
|
||||
|
||||
# Shutdown the pipeline
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
```
|
||||
|
||||
### Service Configuration
|
||||
```python
|
||||
# Use environment variables for configuration
|
||||
service = OpenAILLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY", ""),
|
||||
model="gpt-4o",
|
||||
params={"temperature": 0.7}
|
||||
)
|
||||
```
|
||||
|
||||
### Pipeline Assembly
|
||||
```python
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt_service,
|
||||
context_aggregator.user(),
|
||||
llm_service,
|
||||
tts_service,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
```
|
||||
|
||||
## Commit and PR Guidelines
|
||||
|
||||
### Commit Message Format
|
||||
```
|
||||
<type>(<scope>): <description>
|
||||
|
||||
[optional body]
|
||||
|
||||
[optional footer]
|
||||
```
|
||||
|
||||
Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
|
||||
|
||||
### PR Requirements
|
||||
- **All tests must pass**
|
||||
- **Code must be properly formatted** (Ruff)
|
||||
- **Include appropriate tests** for new functionality
|
||||
- **Update documentation** if needed
|
||||
- **Reference related issues** in description
|
||||
|
||||
### Review Process
|
||||
1. Automated checks must pass
|
||||
2. Manual code review by maintainers
|
||||
3. Documentation review for user-facing changes
|
||||
4. Integration testing for service additions
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
- **Import errors:** Run `uv sync` to ensure dependencies are installed
|
||||
- **Test failures:** Check environment variables in `.env`
|
||||
- **Format errors:** Run `uv run ruff format` before committing
|
||||
- **Type errors:** Ensure all public methods have type hints
|
||||
|
||||
### Development Tips
|
||||
- **Use foundational examples** as starting points for testing
|
||||
- **Check existing services** for integration patterns
|
||||
- **Run tests frequently** during development
|
||||
- **Use IDE integration** for Ruff formatting
|
||||
|
||||
### Getting Help
|
||||
- **Documentation:** [docs.pipecat.ai](https://docs.pipecat.ai)
|
||||
- **Issues:** [GitHub Issues](https://github.com/pipecat-ai/pipecat/issues)
|
||||
33
CHANGELOG.md
33
CHANGELOG.md
@@ -5,6 +5,39 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
|
||||
fired when the pipeline is done running. This can be the result of a
|
||||
`StopFrame`, `CancelFrame` or `EndFrame`.
|
||||
|
||||
```python
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
...
|
||||
```
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
|
||||
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
|
||||
instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
|
||||
when a bot speaks and user input is blocked.
|
||||
|
||||
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
|
||||
`on_client_disconnected` would be triggered when the bot ends the
|
||||
conversation. That is, `on_client_disconnected` should only be triggered when
|
||||
the remote client actually disconnects.
|
||||
|
||||
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
|
||||
was blocked from moving through the Pipeline.
|
||||
|
||||
## [0.0.85] - 2025-09-12
|
||||
|
||||
### Added
|
||||
|
||||
205
examples/foundational/45-openai-agent-basic.py
Normal file
205
examples/foundational/45-openai-agent-basic.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Basic OpenAI Agent service example.
|
||||
|
||||
This example demonstrates how to use the OpenAI Agents SDK within a Pipecat
|
||||
pipeline to create an interactive agent with tool calling capabilities.
|
||||
|
||||
Requirements:
|
||||
- OpenAI API key
|
||||
- OpenAI Agents SDK: pip install openai-agents
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, List
|
||||
|
||||
# Import agents SDK for tools and agent creation
|
||||
from agents import Agent, function_tool
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.frames.frames import LLMRunFrame, TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Transport configuration
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the current weather for a location.
|
||||
|
||||
Args:
|
||||
location: The location to get weather for
|
||||
|
||||
Returns:
|
||||
A weather description string
|
||||
"""
|
||||
# Mock weather data - in real usage, integrate with weather API
|
||||
weather_data = {
|
||||
"San Francisco": "Foggy, 65°F",
|
||||
"New York": "Sunny, 72°F",
|
||||
"London": "Rainy, 59°F",
|
||||
"Tokyo": "Partly cloudy, 68°F",
|
||||
}
|
||||
return weather_data.get(location, f"Weather data not available for {location}")
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact.
|
||||
|
||||
Returns:
|
||||
A random fact string
|
||||
"""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
|
||||
def get_random_fact_tool():
|
||||
"""Example tool function for random facts."""
|
||||
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact.
|
||||
|
||||
Returns:
|
||||
A random fact string.
|
||||
"""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"A group of flamingos is called a 'flamboyance'.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
return get_random_fact
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting OpenAI Agent bot")
|
||||
|
||||
# Set up STT for speech recognition
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
model="nova-2",
|
||||
)
|
||||
|
||||
# Set up TTS for voice output
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY", ""),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Create tools for the agent
|
||||
tools: list[Any] = [
|
||||
get_weather,
|
||||
get_random_fact,
|
||||
]
|
||||
|
||||
# Create the agent with tools
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions="""You are a helpful assistant with access to weather information and random facts.
|
||||
You can:
|
||||
- Check weather for any location using the get_weather tool
|
||||
- Share interesting facts using the get_random_fact tool
|
||||
- Have natural conversations
|
||||
|
||||
Be friendly, informative, and engaging in your responses.""",
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Initialize the OpenAI Agent service with the pre-configured agent
|
||||
agent_service = OpenAIAgentService(
|
||||
agent=agent,
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Set up conversation context with initial system message
|
||||
messages: List[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant with access to weather information and random facts. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = agent_service.create_context_aggregator(context)
|
||||
|
||||
# Create the processing pipeline with context aggregators
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # Speech to text
|
||||
context_aggregator.user(), # User responses
|
||||
agent_service, # OpenAI Agent processing
|
||||
tts, # Text to speech
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Send an initial greeting when client connects
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected, sending greeting")
|
||||
# Kick off the conversation by adding system message and running LLM
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
276
examples/foundational/46-openai-agent-handoffs.py
Normal file
276
examples/foundational/46-openai-agent-handoffs.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Advanced OpenAI Agent service example with handoffs.
|
||||
|
||||
This example demonstrates how to use multiple agents with handoffs in the
|
||||
OpenAI Agents SDK within a Pipecat pipeline, showcasing agent orchestration
|
||||
and specialization.
|
||||
|
||||
Requirements:
|
||||
- OpenAI API key
|
||||
- OpenAI Agents SDK: pip install openai-agents
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.frames.frames import LLMRunFrame, TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Transport configuration
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
def create_weather_tools():
|
||||
"""Create weather-related tools."""
|
||||
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get current weather for a location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"]
|
||||
temp = random.randint(-10, 35)
|
||||
condition = random.choice(conditions)
|
||||
return f"The weather in {location} is {condition} with a temperature of {temp}°C."
|
||||
|
||||
def get_forecast(location: str, days: int = 3) -> str:
|
||||
"""Get weather forecast for multiple days."""
|
||||
forecast = []
|
||||
for i in range(days):
|
||||
conditions = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
temp = random.randint(-5, 30)
|
||||
condition = random.choice(conditions)
|
||||
day = "today" if i == 0 else f"in {i} day{'s' if i > 1 else ''}"
|
||||
forecast.append(f"{day.capitalize()}: {condition}, {temp}°C")
|
||||
return f"Weather forecast for {location}:\n" + "\n".join(forecast)
|
||||
|
||||
return [get_weather, get_forecast]
|
||||
|
||||
|
||||
def create_trivia_tools():
|
||||
"""Create trivia and fact tools."""
|
||||
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact."""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"A group of flamingos is called a 'flamboyance'.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
"Wombat poop is cube-shaped.",
|
||||
"A shrimp's heart is in its head.",
|
||||
"It's impossible to hum while holding your nose.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
def get_science_fact() -> str:
|
||||
"""Get a random science fact."""
|
||||
facts = [
|
||||
"The speed of light in a vacuum is approximately 299,792,458 meters per second.",
|
||||
"DNA stands for Deoxyribonucleic Acid.",
|
||||
"The human brain uses about 20% of the body's total energy.",
|
||||
"There are more possible games of chess than atoms in the observable universe.",
|
||||
"A single bolt of lightning contains enough energy to toast 100,000 slices of bread.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
return [get_random_fact, get_science_fact]
|
||||
|
||||
|
||||
def create_math_tools():
|
||||
"""Create math calculation tools."""
|
||||
|
||||
def calculate(expression: str) -> str:
|
||||
"""Safely calculate a mathematical expression."""
|
||||
try:
|
||||
# Only allow basic math operations for safety
|
||||
allowed_chars = set("0123456789+-*/.() ")
|
||||
if not all(c in allowed_chars for c in expression):
|
||||
return "Sorry, I can only calculate basic math expressions with +, -, *, /, and parentheses."
|
||||
|
||||
result = eval(expression)
|
||||
return f"{expression} = {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating '{expression}': {str(e)}"
|
||||
|
||||
def generate_math_problem() -> str:
|
||||
"""Generate a random math problem."""
|
||||
operations = ["+", "-", "*"]
|
||||
a = random.randint(1, 20)
|
||||
b = random.randint(1, 20)
|
||||
op = random.choice(operations)
|
||||
|
||||
if op == "+":
|
||||
answer = a + b
|
||||
elif op == "-":
|
||||
answer = a - b
|
||||
else: # multiplication
|
||||
answer = a * b
|
||||
|
||||
return f"Here's a math problem for you: {a} {op} {b} = ?"
|
||||
|
||||
return [calculate, generate_math_problem]
|
||||
|
||||
|
||||
async def create_specialist_agents():
|
||||
"""Create specialized agents for different domains."""
|
||||
|
||||
# Weather specialist agent
|
||||
weather_agent = OpenAIAgentService(
|
||||
name="Weather Specialist",
|
||||
instructions="""You are a weather specialist. You provide detailed weather information,
|
||||
forecasts, and weather-related advice. Use your tools to get accurate weather data.
|
||||
Be informative and helpful about weather conditions and what they might mean for
|
||||
outdoor activities.""",
|
||||
tools=create_weather_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Trivia specialist agent
|
||||
trivia_agent = OpenAIAgentService(
|
||||
name="Trivia Master",
|
||||
instructions="""You are a trivia and facts specialist. You love sharing interesting
|
||||
facts, trivia, and educational content. Use your tools to provide fascinating
|
||||
information and engage users with fun facts. Make learning enjoyable!""",
|
||||
tools=create_trivia_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Math specialist agent
|
||||
math_agent = OpenAIAgentService(
|
||||
name="Math Helper",
|
||||
instructions="""You are a mathematics specialist. You help with calculations,
|
||||
math problems, and mathematical concepts. Use your tools to solve problems
|
||||
and generate practice questions. Make math accessible and fun!""",
|
||||
tools=create_math_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return weather_agent, trivia_agent, math_agent
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting OpenAI Agent bot with handoffs")
|
||||
|
||||
# Set up STT for speech recognition
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
model="nova-2",
|
||||
)
|
||||
|
||||
# Set up TTS for voice output
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY", ""),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Create specialist agents
|
||||
weather_agent, trivia_agent, math_agent = await create_specialist_agents()
|
||||
|
||||
# Create the main triage agent that can hand off to specialists
|
||||
triage_agent = OpenAIAgentService(
|
||||
name="Assistant Coordinator",
|
||||
instructions="""You are a helpful assistant coordinator. Your role is to understand
|
||||
what the user needs and direct them to the right specialist:
|
||||
|
||||
- For weather questions, forecasts, or outdoor activity planning -> Weather Specialist
|
||||
- For interesting facts, trivia, or educational content -> Trivia Master
|
||||
- For calculations, math problems, or mathematical help -> Math Helper
|
||||
|
||||
If the request doesn't clearly fit a specialist, you can handle general conversation
|
||||
yourself. Always be friendly and explain when you're connecting them to a specialist.""",
|
||||
handoffs=[weather_agent.agent, trivia_agent.agent, math_agent.agent], # type: ignore
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Set up conversation context with initial system message
|
||||
messages: List[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant coordinator with access to weather information, trivia, and math tools. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = triage_agent.create_context_aggregator(context)
|
||||
|
||||
# Create the processing pipeline with context aggregators
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # Speech to text
|
||||
context_aggregator.user(), # User responses
|
||||
triage_agent, # OpenAI Agent processing
|
||||
tts, # Text to speech
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Send an initial greeting when client connects
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected, sending greeting")
|
||||
# Kick off the conversation by adding system message and running LLM
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please introduce yourself to the user as an AI assistant coordinator who works with specialists for weather, trivia, and math topics.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<=1.99.1",
|
||||
"openai>=1.74.0,<2.0.0",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
@@ -74,7 +74,7 @@ langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-ope
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mcp = [ "mcp[cli]>=1.11.0,<2.0.0" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
@@ -83,7 +83,8 @@ nim = []
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openai-agent = [ "openai-agents~=0.3.0" ]
|
||||
# openpipe = [ "openpipe~=4.50.0" ] # Temporarily disabled due to openai version conflict
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
|
||||
@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
|
||||
|
||||
Returns:
|
||||
The identifier string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
|
||||
|
||||
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
"""Get messages from the LLM context, including standard and LLM-specific messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages including standard and LLM-specific messages.
|
||||
"""
|
||||
return context.get_messages(self.id_for_llm_specific_messages)
|
||||
|
||||
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
|
||||
return "anthropic"
|
||||
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
@@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
@@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
specific function-calling format, enabling tool use with Nova Sonic models.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
into AWS Bedrock's expected tool format for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
|
||||
return "aws"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": messages.messages,
|
||||
@@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about AWS Bedrock.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("google")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Google-formatted messages converted from universal context."""
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
|
||||
return "openai"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -57,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self._get_messages(context)),
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
@@ -91,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self._get_messages(context):
|
||||
for message in self.get_messages(context):
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
@@ -104,14 +110,18 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("openai")
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
# Just a pass-through: messages are already the right type
|
||||
return messages
|
||||
result = []
|
||||
for message in messages:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Extract the actual message content from LLMSpecificMessage
|
||||
result.append(message.message)
|
||||
else:
|
||||
# Standard message, pass through unchanged
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
def _from_standard_tool_choice(
|
||||
self, tool_choice: LLMContextToolChoice | NotGiven
|
||||
|
||||
@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
OpenAI's Realtime API for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
|
||||
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -115,9 +115,28 @@ class PipelineTask(BasePipelineTask):
|
||||
- on_frame_reached_downstream: Called when downstream frames reach the sink
|
||||
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
|
||||
- on_pipeline_started: Called when pipeline starts with StartFrame
|
||||
- on_pipeline_stopped: Called when pipeline stops with StopFrame
|
||||
- on_pipeline_ended: Called when pipeline ends with EndFrame
|
||||
- on_pipeline_cancelled: Called when pipeline is cancelled
|
||||
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
|
||||
This includes:
|
||||
- StopFrame: pipeline was stopped (processors keep connections open)
|
||||
- EndFrame: pipeline ended normally
|
||||
- CancelFrame: pipeline was cancelled
|
||||
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
|
||||
the frame if they need to handle specific cases.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -128,6 +147,10 @@ class PipelineTask(BasePipelineTask):
|
||||
@task.event_handler("on_idle_timeout")
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -264,6 +287,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._register_event_handler("on_pipeline_stopped")
|
||||
self._register_event_handler("on_pipeline_ended")
|
||||
self._register_event_handler("on_pipeline_cancelled")
|
||||
self._register_event_handler("on_pipeline_finished")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
@@ -292,6 +316,27 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._turn_trace_observer
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to handle.
|
||||
|
||||
Returns:
|
||||
The decorator function that registers the handler.
|
||||
"""
|
||||
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return super().event_handler(event_name)
|
||||
|
||||
def add_observer(self, observer: BaseObserver):
|
||||
"""Add an observer to monitor pipeline execution.
|
||||
|
||||
@@ -534,6 +579,7 @@ class PipelineTask(BasePipelineTask):
|
||||
)
|
||||
finally:
|
||||
await self._call_event_handler("on_pipeline_cancelled", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
|
||||
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
|
||||
|
||||
@@ -681,9 +727,11 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_start_event.set()
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._call_event_handler("on_pipeline_ended", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, StopFrame):
|
||||
await self._call_event_handler("on_pipeline_stopped", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
|
||||
@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
return self._num_channels
|
||||
|
||||
def has_audio(self) -> bool:
|
||||
"""Check if both user and bot audio buffers contain data.
|
||||
"""Check if either user or bot audio buffers contain data.
|
||||
|
||||
Returns:
|
||||
True if both buffers contain audio data.
|
||||
True if either buffer contains audio data.
|
||||
"""
|
||||
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
|
||||
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
|
||||
self._bot_audio_buffer
|
||||
)
|
||||
|
||||
|
||||
@@ -811,60 +811,55 @@ class AWSBedrockLLMService(LLMService):
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _create_converse_stream(self, client, request_params):
|
||||
|
||||
@@ -240,6 +240,7 @@ class HeyGenVideoService(AIService):
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -195,6 +195,17 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return self.get_llm_adapter().create_llm_specific_message(message)
|
||||
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
|
||||
209
src/pipecat/services/openai_agent/README.md
Normal file
209
src/pipecat/services/openai_agent/README.md
Normal file
@@ -0,0 +1,209 @@
|
||||
# OpenAI Agents SDK Integration
|
||||
|
||||
This service integrates the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) with Pipecat, enabling powerful agentic workflows with features like:
|
||||
|
||||
- **Agent loops** with tool calling and response streaming
|
||||
- **Handoffs** between specialized agents
|
||||
- **Guardrails** for input/output validation
|
||||
- **Sessions** with automatic conversation history
|
||||
- **Built-in tracing** and monitoring
|
||||
|
||||
## Installation
|
||||
|
||||
Install the OpenAI Agents SDK dependency:
|
||||
|
||||
```bash
|
||||
pip install "pipecat-ai[openai-agent]"
|
||||
# or
|
||||
uv add "pipecat-ai[openai-agent]"
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from pipecat.services.openai_agent import OpenAIAgentService
|
||||
|
||||
# Create a simple agent
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Use in a pipeline
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
agent_service,
|
||||
tts,
|
||||
transport.output(),
|
||||
])
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Tool Integration
|
||||
|
||||
```python
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"Weather in {location}: sunny, 22°C"
|
||||
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Weather Assistant",
|
||||
instructions="Help users with weather information.",
|
||||
tools=[get_weather],
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
```
|
||||
|
||||
### Agent Handoffs
|
||||
|
||||
```python
|
||||
# Create specialized agents
|
||||
weather_agent = OpenAIAgentService(
|
||||
name="Weather Specialist",
|
||||
instructions="Provide weather information and forecasts.",
|
||||
tools=[get_weather, get_forecast],
|
||||
)
|
||||
|
||||
trivia_agent = OpenAIAgentService(
|
||||
name="Trivia Master",
|
||||
instructions="Share interesting facts and trivia.",
|
||||
tools=[get_random_fact],
|
||||
)
|
||||
|
||||
# Create coordinator that can hand off to specialists
|
||||
coordinator = OpenAIAgentService(
|
||||
name="Coordinator",
|
||||
instructions="Route users to the right specialist.",
|
||||
handoffs=[weather_agent.agent, trivia_agent.agent],
|
||||
)
|
||||
```
|
||||
|
||||
### Guardrails
|
||||
|
||||
```python
|
||||
from agents import InputGuardrail, GuardrailFunctionOutput
|
||||
|
||||
async def content_filter(ctx, agent, input_data):
|
||||
# Check input for appropriate content
|
||||
if is_inappropriate(input_data):
|
||||
return GuardrailFunctionOutput(
|
||||
tripwire_triggered=True,
|
||||
output_info="Content not allowed"
|
||||
)
|
||||
return GuardrailFunctionOutput(tripwire_triggered=False)
|
||||
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Safe Assistant",
|
||||
instructions="You are a helpful and safe assistant.",
|
||||
input_guardrails=[InputGuardrail(guardrail_function=content_filter)],
|
||||
)
|
||||
```
|
||||
|
||||
### Session Management
|
||||
|
||||
```python
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Personal Assistant",
|
||||
instructions="Remember user preferences and context.",
|
||||
session_config={
|
||||
"user_id": "user_123",
|
||||
"memory_enabled": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Update session context dynamically
|
||||
agent_service.update_session_context({
|
||||
"user_preferences": {"language": "en", "style": "formal"}
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Basic Parameters
|
||||
|
||||
- `name`: Agent identifier for handoffs and tracing
|
||||
- `instructions`: System prompt defining agent behavior
|
||||
- `api_key`: OpenAI API key (or use `OPENAI_API_KEY` env var)
|
||||
- `streaming`: Enable real-time token streaming (default: True)
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
- `tools`: List of callable functions for the agent to use
|
||||
- `handoffs`: List of other agents this agent can transfer to
|
||||
- `input_guardrails`: Input validation and filtering
|
||||
- `output_guardrails`: Output validation and filtering
|
||||
- `model_config`: Model settings (model, temperature, etc.)
|
||||
- `session_config`: Session and memory configuration
|
||||
|
||||
### Model Configuration
|
||||
|
||||
```python
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Precise Assistant",
|
||||
instructions="Provide accurate, concise responses.",
|
||||
model_config={
|
||||
"model": "gpt-4o",
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 150,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the foundational examples:
|
||||
|
||||
- [`45-openai-agent-basic.py`](../examples/foundational/45-openai-agent-basic.py) - Basic agent with tools
|
||||
- [`46-openai-agent-handoffs.py`](../examples/foundational/46-openai-agent-handoffs.py) - Multi-agent system with handoffs
|
||||
|
||||
## Methods
|
||||
|
||||
### Core Methods
|
||||
|
||||
- `update_agent_config()` - Update instructions and model settings
|
||||
- `add_tool()` - Add new tools dynamically
|
||||
- `add_handoff_agent()` - Add handoff destinations
|
||||
- `get_session_context()` - Get current session state
|
||||
- `update_session_context()` - Update session variables
|
||||
|
||||
### Lifecycle Methods
|
||||
|
||||
Inherited from `AIService`:
|
||||
- `start()` - Initialize the agent
|
||||
- `stop()` - Clean up resources
|
||||
- `cancel()` - Cancel ongoing operations
|
||||
|
||||
## Integration with Pipecat
|
||||
|
||||
The service processes `TextFrame` inputs and generates:
|
||||
- `LLMFullResponseStartFrame` - Response beginning
|
||||
- `LLMTextFrame` - Streaming text tokens (if streaming enabled)
|
||||
- `LLMFullResponseEndFrame` - Response completion
|
||||
|
||||
This integrates seamlessly with Pipecat's conversation pipeline and context aggregators.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The service includes robust error handling for:
|
||||
- Missing API keys or SDK installation
|
||||
- Agent processing failures
|
||||
- Network connectivity issues
|
||||
- Malformed tool responses
|
||||
|
||||
Errors are emitted as `ErrorFrame` objects in the pipeline.
|
||||
|
||||
## Requirements
|
||||
|
||||
- OpenAI API key
|
||||
- `openai-agents` package
|
||||
- Python 3.10+
|
||||
|
||||
## Limitations
|
||||
|
||||
- Currently supports OpenAI models only (via Agents SDK)
|
||||
- Handoffs work within individual requests (no cross-request state)
|
||||
- Real-time voice features require additional setup
|
||||
11
src/pipecat/services/openai_agent/__init__.py
Normal file
11
src/pipecat/services/openai_agent/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Agents SDK service for Pipecat integration."""
|
||||
|
||||
from .agent_service import OpenAIAgentService
|
||||
|
||||
__all__ = ["OpenAIAgentService"]
|
||||
567
src/pipecat/services/openai_agent/agent_service.py
Normal file
567
src/pipecat/services/openai_agent/agent_service.py
Normal file
@@ -0,0 +1,567 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Agents SDK integration service.
|
||||
|
||||
Provides integration with the OpenAI Agents SDK for building AI applications
|
||||
within Pipecat pipelines. This service allows leveraging agent loops, handoffs,
|
||||
guardrails, sessions, and tools from the OpenAI Agents SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Union,
|
||||
override,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from agents import Agent, InputGuardrail, OutputGuardrail, Runner, Tool
|
||||
from agents.result import RunResult, RunResultStreaming
|
||||
from agents.stream_events import StreamEvent
|
||||
except ImportError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI Agents SDK, you need to `pip install openai-agents`. "
|
||||
"Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolLike(Protocol):
|
||||
"""Protocol for tool-like objects."""
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Tool call interface."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentLike(Protocol):
|
||||
"""Protocol for agent-like objects."""
|
||||
|
||||
name: str
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Agent call interface."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIAgentContextAggregatorPair:
|
||||
"""Pair of OpenAI Agent context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
_user: "OpenAIAgentUserContextAggregator"
|
||||
_assistant: "OpenAIAgentAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIAgentUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAgentAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAIAgentService(AIService):
|
||||
"""OpenAI Agents SDK service for Pipecat.
|
||||
|
||||
Integrates the OpenAI Agents SDK with Pipecat's pipeline architecture,
|
||||
enabling advanced agentic workflows with features like handoffs, guardrails,
|
||||
sessions, and tools within real-time conversational AI applications.
|
||||
|
||||
The service processes text input frames and generates streaming responses
|
||||
using the agent's configured capabilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: Optional[Agent] = None,
|
||||
name: str = "Assistant",
|
||||
instructions: Union[str, Sequence[str]] = "You are a helpful assistant.",
|
||||
handoffs: Optional[Sequence[AgentLike]] = None,
|
||||
tools: Optional[Sequence[ToolLike]] = None,
|
||||
input_guardrails: Optional[Sequence[InputGuardrail]] = None,
|
||||
output_guardrails: Optional[Sequence[OutputGuardrail]] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
session_config: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
streaming: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Agent service.
|
||||
|
||||
Args:
|
||||
agent: Pre-configured Agent instance. If provided, other agent configuration
|
||||
parameters will be ignored.
|
||||
name: Name of the agent for identification and handoffs.
|
||||
instructions: System instructions that define the agent's behavior.
|
||||
handoffs: List of other agents this agent can hand off to.
|
||||
tools: List of callable functions the agent can use as tools.
|
||||
input_guardrails: List of input validation guardrails.
|
||||
output_guardrails: List of output validation guardrails.
|
||||
model_config: Configuration for the underlying language model.
|
||||
session_config: Configuration for session management.
|
||||
api_key: OpenAI API key. If not provided, will use OPENAI_API_KEY env var.
|
||||
streaming: Whether to use streaming responses for real-time output.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set up API key
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
elif not os.getenv("OPENAI_API_KEY"):
|
||||
logger.warning("No OpenAI API key provided. Set OPENAI_API_KEY environment variable.")
|
||||
|
||||
# Create or use existing agent
|
||||
if agent:
|
||||
self._agent = agent
|
||||
else:
|
||||
# Convert sequences to lists and handle string instructions
|
||||
agent_handoffs: List[Any] = list(handoffs) if handoffs else []
|
||||
agent_tools: List[Any] = list(tools) if tools else []
|
||||
agent_input_guardrails: List[Any] = list(input_guardrails) if input_guardrails else []
|
||||
agent_output_guardrails: List[Any] = (
|
||||
list(output_guardrails) if output_guardrails else []
|
||||
)
|
||||
|
||||
# Handle instructions - convert sequence to string if needed
|
||||
if isinstance(instructions, str):
|
||||
agent_instructions = instructions
|
||||
else:
|
||||
agent_instructions = " ".join(str(instr) for instr in instructions)
|
||||
|
||||
self._agent = Agent(
|
||||
name=name,
|
||||
instructions=agent_instructions,
|
||||
handoffs=agent_handoffs,
|
||||
tools=agent_tools,
|
||||
input_guardrails=agent_input_guardrails,
|
||||
output_guardrails=agent_output_guardrails,
|
||||
model=model_config.get("model", "gpt-4o") if model_config else "gpt-4o",
|
||||
)
|
||||
|
||||
self._streaming = streaming
|
||||
self._session_config = session_config or {}
|
||||
self._current_session = None
|
||||
self._accumulated_text = ""
|
||||
|
||||
# Set model name for metrics
|
||||
if model_config and "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
else:
|
||||
self.set_model_name("gpt-4o") # Default model
|
||||
|
||||
logger.info(f"Initialized OpenAI Agent service: {self._agent.name}")
|
||||
|
||||
@property
|
||||
def agent(self) -> Agent:
|
||||
"""Get the underlying OpenAI Agent.
|
||||
|
||||
Returns:
|
||||
The configured Agent instance.
|
||||
"""
|
||||
return self._agent
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIAgentContextAggregatorPair:
|
||||
"""Create OpenAI-specific context aggregators for agent interactions.
|
||||
|
||||
Creates a pair of context aggregators optimized for OpenAI Agent interactions,
|
||||
including support for function calls, tool usage, and conversation management.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
OpenAIAgentContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIAgentContextAggregatorPair.
|
||||
"""
|
||||
user = OpenAIAgentUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAgentAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIAgentContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def update_agent_config(
|
||||
self,
|
||||
*,
|
||||
instructions: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Update agent configuration dynamically.
|
||||
|
||||
Args:
|
||||
instructions: New system instructions for the agent.
|
||||
model_config: Updated model configuration.
|
||||
**kwargs: Additional agent configuration parameters.
|
||||
"""
|
||||
if instructions:
|
||||
self._agent.instructions = instructions
|
||||
logger.info(f"Updated agent instructions for {self._agent.name}")
|
||||
|
||||
if model_config:
|
||||
# Note: OpenAI Agents SDK handles model configuration during agent creation
|
||||
# We can't update model_config after agent is created, but we can update our model name
|
||||
if "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
logger.info(f"Updated model config for {self._agent.name}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the OpenAI Agent service.
|
||||
|
||||
Initializes the agent session and prepares for processing.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
logger.info(f"Starting OpenAI Agent service: {self._agent.name}")
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the OpenAI Agent service.
|
||||
|
||||
Cleans up resources and ends the current session.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
logger.info(f"Stopping OpenAI Agent service: {self._agent.name}")
|
||||
await super().stop(frame)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the OpenAI Agent service.
|
||||
|
||||
Cancels any ongoing operations.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
logger.info(f"Cancelling OpenAI Agent service: {self._agent.name}")
|
||||
await super().cancel(frame)
|
||||
|
||||
@override
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
"""Process frames and handle agent interactions.
|
||||
|
||||
Processes OpenAILLMContextFrame and TextFrame by running them through the OpenAI Agent
|
||||
and streams the results back as LLM frames.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
# Process context frame through the agent
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
# Extract the latest user message from the context
|
||||
messages = frame.context.get_messages()
|
||||
if messages:
|
||||
# Get the last user message
|
||||
for message in reversed(messages):
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
user_input = " ".join(text_parts)
|
||||
else:
|
||||
user_input = str(content)
|
||||
|
||||
if user_input.strip():
|
||||
await self._process_agent_request(user_input)
|
||||
break
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent context: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
elif isinstance(frame, TextFrame):
|
||||
# Process text input through the agent directly (for backwards compatibility)
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self._process_agent_request(frame.text)
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent request: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
else:
|
||||
# For frames we don't handle, pass them through with direction
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_agent_request(self, input_text: str):
|
||||
"""Process an agent request and stream the results.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
logger.debug(f"Processing agent request: {input_text}")
|
||||
|
||||
if self._streaming:
|
||||
await self._process_streaming_response(input_text)
|
||||
else:
|
||||
await self._process_non_streaming_response(input_text)
|
||||
|
||||
async def _process_streaming_response(self, input_text: str):
|
||||
"""Process a streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent with streaming
|
||||
result: RunResultStreaming = Runner.run_streamed(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
has_streaming_deltas = False
|
||||
|
||||
# Process the stream events
|
||||
async for event in result.stream_events():
|
||||
if event.type == "raw_response_event":
|
||||
# Handle token-by-token streaming
|
||||
# Only check for delta on events that are known to have it
|
||||
if hasattr(event.data, "delta") and getattr(event.data, "delta", None):
|
||||
delta_text = getattr(event.data, "delta", "")
|
||||
if delta_text:
|
||||
has_streaming_deltas = True
|
||||
self._accumulated_text += delta_text
|
||||
await self.push_frame(LLMTextFrame(text=delta_text))
|
||||
|
||||
elif event.type == "run_item_stream_event":
|
||||
# Handle completed items
|
||||
if event.item.type == "message_output_item":
|
||||
# Only process complete message if we didn't get streaming deltas
|
||||
if not has_streaming_deltas:
|
||||
message_text = self._extract_message_text(event.item)
|
||||
logger.debug(
|
||||
f"Processing complete message (no deltas): {message_text[:50]}..."
|
||||
if len(message_text) > 50
|
||||
else f"Processing complete message: {message_text}"
|
||||
)
|
||||
if message_text:
|
||||
await self.push_frame(LLMTextFrame(text=message_text))
|
||||
|
||||
elif event.item.type == "tool_call_item":
|
||||
# Use getattr for safe attribute access
|
||||
tool_name = getattr(event.item, "tool_name", "unknown")
|
||||
logger.debug(f"Tool called: {tool_name}")
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
output = getattr(event.item, "output", "no output")
|
||||
logger.debug(f"Tool output: {output}")
|
||||
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
logger.debug(f"Agent updated: {event.new_agent.name}")
|
||||
|
||||
# Reset accumulated text for next request
|
||||
self._accumulated_text = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming response: {e}")
|
||||
raise
|
||||
|
||||
async def _process_non_streaming_response(self, input_text: str):
|
||||
"""Process a non-streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent without streaming
|
||||
result: RunResult = await Runner.run(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
# Send the final output
|
||||
if result.final_output:
|
||||
await self.push_frame(LLMTextFrame(text=result.final_output))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in non-streaming response: {e}")
|
||||
raise
|
||||
|
||||
def _extract_message_text(self, item) -> str:
|
||||
"""Extract text from a message output item.
|
||||
|
||||
Args:
|
||||
item: The message output item from the agent.
|
||||
|
||||
Returns:
|
||||
The extracted text content.
|
||||
"""
|
||||
try:
|
||||
# Handle OpenAI Agents SDK MessageOutputItem format
|
||||
if hasattr(item, "raw_item") and hasattr(item.raw_item, "content"):
|
||||
content = item.raw_item.content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for content_part in content:
|
||||
if hasattr(content_part, "text"):
|
||||
text_parts.append(content_part.text)
|
||||
elif (
|
||||
isinstance(content_part, dict)
|
||||
and content_part.get("type") == "output_text"
|
||||
):
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
# Handle direct content attribute
|
||||
elif hasattr(item, "content"):
|
||||
if isinstance(item.content, str):
|
||||
return item.content
|
||||
elif isinstance(item.content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for content_part in item.content:
|
||||
if isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, str):
|
||||
text_parts.append(content_part)
|
||||
return "".join(text_parts)
|
||||
|
||||
# If no text content found, return empty string instead of str(item)
|
||||
logger.debug(f"No extractable text content found in item: {type(item)}")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from message item: {e}")
|
||||
return ""
|
||||
|
||||
async def add_tool(self, tool_function: ToolLike):
|
||||
"""Add a tool function to the agent.
|
||||
|
||||
Args:
|
||||
tool_function: A callable function or Tool object to add as a tool.
|
||||
"""
|
||||
if hasattr(self._agent, "tools"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
tools_list: List[Any] = self._agent.tools
|
||||
tools_list.append(tool_function)
|
||||
tool_name = getattr(
|
||||
tool_function, "__name__", getattr(tool_function, "name", "unknown")
|
||||
)
|
||||
logger.info(f"Added tool {tool_name} to agent {self._agent.name}")
|
||||
|
||||
async def add_handoff_agent(self, agent: AgentLike):
|
||||
"""Add a handoff agent.
|
||||
|
||||
Args:
|
||||
agent: Another Agent instance or handoff object that this agent can hand off to.
|
||||
"""
|
||||
if hasattr(self._agent, "handoffs"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
handoffs_list: List[Any] = self._agent.handoffs
|
||||
handoffs_list.append(agent)
|
||||
agent_name = getattr(agent, "name", "unknown")
|
||||
logger.info(f"Added handoff agent {agent_name} to agent {self._agent.name}")
|
||||
|
||||
def get_session_context(self) -> Dict[str, Any]:
|
||||
"""Get the current session context.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current session context.
|
||||
"""
|
||||
return self._session_config.copy()
|
||||
|
||||
def update_session_context(self, context: Dict[str, Any]):
|
||||
"""Update the session context.
|
||||
|
||||
Args:
|
||||
context: Dictionary of context updates to apply.
|
||||
"""
|
||||
self._session_config.update(context)
|
||||
logger.debug(f"Updated session context for agent {self._agent.name}")
|
||||
|
||||
|
||||
class OpenAIAgentUserContextAggregator(LLMUserContextAggregator):
|
||||
"""OpenAI Agent-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for OpenAI Agent services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAgentAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""OpenAI Agent-specific assistant context aggregator.
|
||||
|
||||
Handles aggregation of assistant messages for OpenAI Agent services,
|
||||
with specialized support for OpenAI's function calling format,
|
||||
tool usage tracking, and agent interaction management.
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -478,7 +478,11 @@ class SmallWebRTCClient:
|
||||
self._screen_video_track = None
|
||||
self._audio_output_track = None
|
||||
self._video_output_track = None
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._closing:
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
async def _handle_app_message(self, message: Any):
|
||||
"""Handle incoming application messages."""
|
||||
|
||||
@@ -138,7 +138,6 @@ class FastAPIWebsocketClient:
|
||||
):
|
||||
logger.warning("Closing already disconnected websocket!")
|
||||
self._closing = True
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect the WebSocket client."""
|
||||
@@ -152,8 +151,6 @@ class FastAPIWebsocketClient:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception while closing the websocket: {e}")
|
||||
finally:
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def trigger_client_disconnected(self):
|
||||
"""Trigger the client disconnected callback."""
|
||||
@@ -298,7 +295,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._client.trigger_client_disconnected()
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._client.is_closing:
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
@@ -446,6 +446,9 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
"""Serialize and send a frame through the WebSocket."""
|
||||
if self._client.is_closing or not self._client.is_connected:
|
||||
return
|
||||
|
||||
if not self._params.serializer:
|
||||
return
|
||||
|
||||
|
||||
172
test_openai_agent.py
Normal file
172
test_openai_agent.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Simple test script for OpenAI Agent service."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Mock the OpenAI API key for testing
|
||||
os.environ["OPENAI_API_KEY"] = "test-key-for-testing"
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai_agent import OpenAIAgentService
|
||||
|
||||
|
||||
async def test_basic_functionality():
|
||||
"""Test basic OpenAI Agent service functionality."""
|
||||
print("🧪 Testing OpenAI Agent Service...")
|
||||
|
||||
# Create a simple weather tool for testing
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"The weather in {location} is sunny and 22°C."
|
||||
|
||||
try:
|
||||
# Create the service
|
||||
print("📋 Creating OpenAI Agent service...")
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
tools=[get_weather],
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
print(f"✅ Service created successfully!")
|
||||
print(f" - Agent name: {service.agent.name}")
|
||||
print(f" - Model name: {service.model_name}")
|
||||
print(f" - Streaming enabled: {service._streaming}")
|
||||
|
||||
# Test basic configuration
|
||||
print("⚙️ Testing configuration updates...")
|
||||
service.update_agent_config(
|
||||
instructions="Updated test instructions",
|
||||
model_config={"model": "gpt-4o", "temperature": 0.5},
|
||||
)
|
||||
|
||||
print(f"✅ Configuration updated!")
|
||||
print(f" - New instructions: {service.agent.instructions}")
|
||||
print(f" - New model: {service.model_name}")
|
||||
|
||||
# Test session context
|
||||
print("💾 Testing session context...")
|
||||
service.update_session_context({"user_id": "test-user", "session": "test-session"})
|
||||
context = service.get_session_context()
|
||||
|
||||
print(f"✅ Session context managed!")
|
||||
print(f" - Context keys: {list(context.keys())}")
|
||||
|
||||
# Test adding tools
|
||||
print("🔧 Testing tool management...")
|
||||
|
||||
def get_time() -> str:
|
||||
"""Get current time."""
|
||||
return "The current time is 3:00 PM."
|
||||
|
||||
await service.add_tool(get_time)
|
||||
print(f"✅ Tool added successfully!")
|
||||
|
||||
print("\n🎉 All basic functionality tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_frame_processing():
|
||||
"""Test frame processing with mocked responses."""
|
||||
print("\n🔄 Testing frame processing...")
|
||||
|
||||
try:
|
||||
# Mock the Runner to avoid actual API calls
|
||||
with patch("pipecat.services.openai_agent.agent_service.Runner") as mock_runner:
|
||||
# Set up mock responses
|
||||
mock_stream_result = MagicMock()
|
||||
|
||||
# Mock stream events
|
||||
async def mock_stream_events():
|
||||
# Simulate streaming response
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="Hello "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="from "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="agent!"))
|
||||
|
||||
# Simulate completed message
|
||||
mock_item = MagicMock()
|
||||
mock_item.type = "message_output_item"
|
||||
mock_item.content = "Hello from agent!"
|
||||
yield MagicMock(type="run_item_stream_event", item=mock_item)
|
||||
|
||||
mock_stream_result.stream_events.return_value = mock_stream_events()
|
||||
mock_runner.run_streamed.return_value = mock_stream_result
|
||||
|
||||
# Create service with mocked runner
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Collect output frames
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
print(f" 📤 Frame: {type(frame).__name__}")
|
||||
if hasattr(frame, "text"):
|
||||
print(f" Text: '{frame.text}'")
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
print("📝 Processing text frame...")
|
||||
text_frame = TextFrame("Hello, how are you?")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait for async processing
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
print(f"✅ Frame processing completed!")
|
||||
print(f" - Generated {len(output_frames)} output frames")
|
||||
|
||||
# Check if we got expected frame types
|
||||
frame_types = [type(frame).__name__ for frame in output_frames]
|
||||
print(f" - Frame types: {frame_types}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Frame processing test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting OpenAI Agent Service Tests\n")
|
||||
|
||||
try:
|
||||
# Run basic functionality tests
|
||||
basic_test = await test_basic_functionality()
|
||||
|
||||
# Run frame processing tests
|
||||
frame_test = await test_frame_processing()
|
||||
|
||||
# Summary
|
||||
print(f"\n📊 Test Results:")
|
||||
print(f" - Basic functionality: {'✅ PASS' if basic_test else '❌ FAIL'}")
|
||||
print(f" - Frame processing: {'✅ PASS' if frame_test else '❌ FAIL'}")
|
||||
|
||||
if basic_test and frame_test:
|
||||
print(f"\n🎉 All tests passed! The OpenAI Agent service is working correctly.")
|
||||
else:
|
||||
print(f"\n⚠️ Some tests failed. Please check the output above.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test suite failed with error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
33
test_simple_agent.py
Normal file
33
test_simple_agent.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# Test the actual agents package API
|
||||
try:
|
||||
from agents import Agent, run
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
name="test-agent",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
print("✅ Agent created successfully!")
|
||||
print(f"Agent name: {agent.name}")
|
||||
|
||||
# Test a simple conversation
|
||||
async def test_agent():
|
||||
result = await run(agent, "Hello, how are you?")
|
||||
print(f"Agent response: {result}")
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_agent())
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
998
tests/test_get_llm_invocation_params.py
Normal file
998
tests/test_get_llm_invocation_params.py
Normal file
@@ -0,0 +1,998 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for LLM adapters' get_llm_invocation_params() method.
|
||||
|
||||
These tests focus specifically on the "messages" field generation for different adapters, ensuring:
|
||||
|
||||
For OpenAI adapter:
|
||||
1. LLMStandardMessage objects are passed through unchanged
|
||||
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
|
||||
3. Complex message structures (like multi-part content) are preserved
|
||||
4. System instructions are preserved throughout messages at any position
|
||||
|
||||
For Gemini adapter:
|
||||
1. LLMStandardMessage objects are converted to Gemini Content format
|
||||
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
|
||||
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
|
||||
4. System messages are extracted as system_instruction (without duplication)
|
||||
5. Single system instruction is converted to user message when no other messages exist
|
||||
6. Multiple system instructions: first extracted, later ones converted to user messages
|
||||
|
||||
For Anthropic adapter:
|
||||
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
|
||||
For AWS Bedrock adapter:
|
||||
1. LLMStandardMessage objects are converted to AWS Bedrock format
|
||||
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from google.genai.types import Content, Part
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = OpenAILLMAdapter()
|
||||
|
||||
def test_standard_messages_passed_through_unchanged(self):
|
||||
"""Test that LLMStandardMessage objects are passed through unchanged to OpenAI params."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify messages are passed through unchanged
|
||||
self.assertEqual(params["messages"], standard_messages)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify content matches exactly
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
|
||||
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that OpenAI-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Gemini specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
{"role": "assistant", "content": "OpenAI specific response"}
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and OpenAI-specific ones
|
||||
# (3 total: system, standard user, openai assistant)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Standard user message")
|
||||
self.assertEqual(
|
||||
params["messages"][2], {"role": "assistant", "content": "OpenAI specific response"}
|
||||
)
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved."""
|
||||
# Create a message with complex content structure (text + image)
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..."},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant that can analyze images."},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex content is preserved
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][1], complex_image_message)
|
||||
self.assertEqual(params["messages"][2], multi_text_message)
|
||||
|
||||
# Verify the image message structure is maintained
|
||||
image_content = params["messages"][1]["content"]
|
||||
self.assertIsInstance(image_content, list)
|
||||
self.assertEqual(len(image_content), 2)
|
||||
self.assertEqual(image_content[0]["type"], "text")
|
||||
self.assertEqual(image_content[1]["type"], "image_url")
|
||||
|
||||
# Verify the multi-text message structure is maintained
|
||||
text_content = params["messages"][2]["content"]
|
||||
self.assertIsInstance(text_content, list)
|
||||
self.assertEqual(len(text_content), 3)
|
||||
for i, text_block in enumerate(text_content):
|
||||
self.assertEqual(text_block["type"], "text")
|
||||
self.assertEqual(text_content[0]["text"], "Let me analyze this step by step:")
|
||||
self.assertEqual(text_content[1]["text"], "1. First, I'll examine the visual elements")
|
||||
self.assertEqual(text_content[2]["text"], "2. Then I'll provide my conclusions")
|
||||
|
||||
def test_system_instructions_preserved_throughout_messages(self):
|
||||
"""Test that OpenAI adapter preserves system instructions sprinkled throughout messages."""
|
||||
# Create messages with system instructions at different positions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# OpenAI should preserve all messages unchanged, including multiple system messages
|
||||
self.assertEqual(len(params["messages"]), 7)
|
||||
|
||||
# Verify system messages are preserved at their original positions
|
||||
self.assertEqual(params["messages"][0]["role"], "system")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
self.assertEqual(params["messages"][3]["role"], "system")
|
||||
self.assertEqual(params["messages"][3]["content"], "Remember to be concise.")
|
||||
|
||||
self.assertEqual(params["messages"][5]["role"], "system")
|
||||
self.assertEqual(params["messages"][5]["content"], "Use simple language.")
|
||||
|
||||
# Verify other messages remain unchanged
|
||||
self.assertEqual(params["messages"][1]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][4]["role"], "user")
|
||||
self.assertEqual(params["messages"][6]["role"], "assistant")
|
||||
|
||||
|
||||
class TestGeminiGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = GeminiLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_gemini_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Gemini Content format."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are converted to Gemini format (2 messages: user + model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg, Content)
|
||||
self.assertEqual(user_msg.role, "user")
|
||||
self.assertEqual(len(user_msg.parts), 1)
|
||||
self.assertEqual(user_msg.parts[0].text, "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant -> model)
|
||||
model_msg = params["messages"][1]
|
||||
self.assertIsInstance(model_msg, Content)
|
||||
self.assertEqual(model_msg.role, "model")
|
||||
self.assertEqual(len(model_msg.parts), 1)
|
||||
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Gemini-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific message"}
|
||||
),
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
Content(role="model", parts=[Part(text="Gemini specific response")]),
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and Gemini-specific ones
|
||||
# (2 total: converted standard user + gemini model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "Standard user message")
|
||||
|
||||
self.assertEqual(params["messages"][1].role, "model")
|
||||
self.assertEqual(params["messages"][1].parts[0].text, "Gemini specific response")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved and converted.
|
||||
|
||||
This test covers image, audio, and multi-text content conversion to Gemini format.
|
||||
"""
|
||||
# Create a message with complex content structure (text + image)
|
||||
# Using a minimal valid base64 image data
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with audio input (text + audio)
|
||||
# Using a minimal valid base64 audio data (16 bytes of WAV header)
|
||||
audio_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you transcribe this audio?"},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=",
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can analyze images and audio.",
|
||||
},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
audio_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(
|
||||
params["system_instruction"],
|
||||
"You are a helpful assistant that can analyze images and audio.",
|
||||
)
|
||||
|
||||
# Verify complex content is converted to Gemini format
|
||||
# Note: Gemini adapter may add system instruction back as user message in some cases
|
||||
self.assertGreaterEqual(len(params["messages"]), 3)
|
||||
|
||||
# Find the different message types
|
||||
user_with_image = None
|
||||
model_with_text = None
|
||||
user_with_audio = None
|
||||
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and len(msg.parts) == 2:
|
||||
# Check if it's image or audio based on the text content
|
||||
if hasattr(msg.parts[0], "text") and "image" in msg.parts[0].text:
|
||||
user_with_image = msg
|
||||
elif hasattr(msg.parts[0], "text") and "audio" in msg.parts[0].text:
|
||||
user_with_audio = msg
|
||||
elif msg.role == "model" and len(msg.parts) == 3:
|
||||
model_with_text = msg
|
||||
|
||||
# Verify the image message structure is converted properly
|
||||
self.assertIsNotNone(user_with_image, "Should have user message with image")
|
||||
self.assertEqual(len(user_with_image.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_image.parts[0].text, "What's in this image?")
|
||||
|
||||
# Second part should be image data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_image.parts[1].inline_data)
|
||||
self.assertEqual(user_with_image.parts[1].inline_data.mime_type, "image/jpeg")
|
||||
|
||||
# Verify the audio message structure is converted properly
|
||||
self.assertIsNotNone(user_with_audio, "Should have user message with audio")
|
||||
self.assertEqual(len(user_with_audio.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_audio.parts[0].text, "Can you transcribe this audio?")
|
||||
|
||||
# Second part should be audio data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_audio.parts[1].inline_data)
|
||||
self.assertEqual(user_with_audio.parts[1].inline_data.mime_type, "audio/wav")
|
||||
|
||||
# Verify the multi-text message structure is converted properly
|
||||
self.assertIsNotNone(model_with_text, "Should have model message with multi-text")
|
||||
self.assertEqual(len(model_with_text.parts), 3)
|
||||
|
||||
# All parts should be text
|
||||
expected_texts = [
|
||||
"Let me analyze this step by step:",
|
||||
"1. First, I'll examine the visual elements",
|
||||
"2. Then I'll provide my conclusions",
|
||||
]
|
||||
for i, expected_text in enumerate(expected_texts):
|
||||
self.assertEqual(model_with_text.parts[i].text, expected_text)
|
||||
|
||||
def test_single_system_instruction_converted_to_user(self):
|
||||
"""Test that when there's only a system instruction, it gets converted to user message."""
|
||||
# Create context with only a system message
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# But since there are no other messages, it should also be added back as a user message
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "You are a helpful assistant.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
# Create messages with multiple system instructions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# First system instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 6 messages (original 7 minus 1 system instruction that was extracted)
|
||||
self.assertEqual(len(params["messages"]), 6)
|
||||
|
||||
# Find the converted system messages (should be user role now)
|
||||
converted_system_messages = []
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and (
|
||||
msg.parts[0].text == "Remember to be concise."
|
||||
or msg.parts[0].text == "Use simple language."
|
||||
):
|
||||
converted_system_messages.append(msg.parts[0].text)
|
||||
|
||||
# Should have 2 converted system messages
|
||||
self.assertEqual(len(converted_system_messages), 2)
|
||||
self.assertIn("Remember to be concise.", converted_system_messages)
|
||||
self.assertIn("Use simple language.", converted_system_messages)
|
||||
|
||||
# Verify that regular user and assistant messages are preserved
|
||||
user_messages = [msg for msg in params["messages"] if msg.role == "user"]
|
||||
model_messages = [msg for msg in params["messages"] if msg.role == "model"]
|
||||
|
||||
# Should have 4 user messages: 2 original + 2 converted from system
|
||||
self.assertEqual(len(user_messages), 4)
|
||||
# Should have 2 model messages (converted from assistant)
|
||||
self.assertEqual(len(model_messages), 2)
|
||||
|
||||
|
||||
class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AnthropicLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_anthropic_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Anthropic MessageParam format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertEqual(user_msg["content"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant)
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Anthropic-specific messages are included and others are filtered out."""
|
||||
# Create anthropic-specific message content
|
||||
anthropic_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "fake_data"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(anthropic_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + anthropic-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + anthropic-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + anthropic text + anthropic image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "image")
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["content"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to Anthropic format."""
|
||||
# Create a complex message with both text and image content
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,fake_image_data"},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: Anthropic adapter reorders single images to come before text, as per Anthropic docs
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "image")
|
||||
self.assertIn("source", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["source"]["type"], "base64")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["media_type"], "image/jpeg")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["data"], "fake_image_data")
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System instruction should be extracted from first message
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_converted_to_user(self):
|
||||
"""Test that a single system message is converted to user role when no other messages exist."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System should be NOT_GIVEN since we only have one message
|
||||
from anthropic import NOT_GIVEN
|
||||
|
||||
self.assertEqual(params["system"], NOT_GIVEN)
|
||||
|
||||
# Single system message should be converted to user role
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
|
||||
class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AWSBedrockLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_aws_bedrock_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to AWS Bedrock format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user) - should be converted to AWS Bedrock format
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 1)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant) - should be converted to AWS Bedrock format
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 1)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that AWS-specific messages are included and others are filtered out."""
|
||||
# Create aws-specific message content (which is what AWS Bedrock uses)
|
||||
aws_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello"},
|
||||
{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(message=aws_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + aws-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + aws-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + aws text + aws image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertIn("image", user_msg["content"][2])
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to AWS Bedrock format."""
|
||||
# Create a complex message with both text and image content
|
||||
# Use a valid base64 string for the image
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: AWS Bedrock adapter reorders single images to come before text, like Anthropic
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertIn("image", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["image"]["format"], "jpeg")
|
||||
self.assertIn("source", user_msg["content"][0]["image"])
|
||||
self.assertIn("bytes", user_msg["content"][0]["image"]["source"])
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted from first message (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"][0]["text"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_handling(self):
|
||||
"""Test that a single system message is extracted as system parameter and no messages remain."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System should be extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# No messages should remain after system extraction
|
||||
self.assertEqual(len(params["messages"]), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
286
tests/test_openai_agent_service.py
Normal file
286
tests/test_openai_agent_service.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for OpenAI Agent service."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add src to path for testing
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""Mock Agent for testing."""
|
||||
|
||||
def __init__(self, name="Test Agent", instructions="Test instructions"):
|
||||
self.name = name
|
||||
self.instructions = instructions
|
||||
self.tools = []
|
||||
self.handoffs = []
|
||||
|
||||
|
||||
class MockRunResult:
|
||||
"""Mock RunResult for testing."""
|
||||
|
||||
def __init__(self, final_output="Test response"):
|
||||
self.final_output = final_output
|
||||
|
||||
|
||||
class MockStreamEvent:
|
||||
"""Mock StreamEvent for testing."""
|
||||
|
||||
def __init__(self, event_type, data=None, item=None):
|
||||
self.type = event_type
|
||||
self.data = data
|
||||
self.item = item
|
||||
|
||||
|
||||
class MockMessageItem:
|
||||
"""Mock message item for testing."""
|
||||
|
||||
def __init__(self, content="Test content"):
|
||||
self.type = "message_output_item"
|
||||
self.content = content
|
||||
|
||||
|
||||
class MockRunner:
|
||||
"""Mock Runner for testing."""
|
||||
|
||||
@staticmethod
|
||||
async def run(agent, input_text, context=None):
|
||||
return MockRunResult("Mocked response")
|
||||
|
||||
@staticmethod
|
||||
def run_streamed(agent, input_text, context=None):
|
||||
class MockStreamResult:
|
||||
async def stream_events(self):
|
||||
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="Test "))
|
||||
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="response"))
|
||||
yield MockStreamEvent(
|
||||
"run_item_stream_event", item=MockMessageItem("Test response")
|
||||
)
|
||||
|
||||
return MockStreamResult()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_agents():
|
||||
"""Mock the OpenAI Agents SDK imports."""
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"agents": MagicMock(),
|
||||
"agents.stream_events": MagicMock(),
|
||||
"agents.result": MagicMock(),
|
||||
},
|
||||
):
|
||||
# Mock the classes and functions we need
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.return_value = MockAgent()
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.run = AsyncMock(return_value=MockRunResult())
|
||||
mock_runner.run_streamed = MagicMock(return_value=MockRunner.run_streamed(None, None))
|
||||
|
||||
with (
|
||||
patch("pipecat.services.openai_agent.agent_service.Agent", mock_agent),
|
||||
patch("pipecat.services.openai_agent.agent_service.Runner", mock_runner),
|
||||
):
|
||||
yield {
|
||||
"Agent": mock_agent,
|
||||
"Runner": mock_runner,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_init(mock_openai_agents):
|
||||
"""Test OpenAI Agent service initialization."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
|
||||
)
|
||||
|
||||
assert service.agent.name == "Test Agent"
|
||||
assert service._streaming is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_process_text_frame_streaming(mock_openai_agents):
|
||||
"""Test processing text frame with streaming enabled."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
|
||||
)
|
||||
|
||||
# Mock the push_frame method to capture output
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
text_frame = TextFrame("Hello, agent!")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait a bit for async processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check that appropriate frames were generated
|
||||
assert len(output_frames) > 0
|
||||
assert any(isinstance(frame, LLMFullResponseStartFrame) for frame in output_frames)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_process_text_frame_non_streaming(mock_openai_agents):
|
||||
"""Test processing text frame with streaming disabled."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=False
|
||||
)
|
||||
|
||||
# Mock the push_frame method to capture output
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
text_frame = TextFrame("Hello, agent!")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait a bit for async processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check that appropriate frames were generated
|
||||
assert len(output_frames) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_update_config(mock_openai_agents):
|
||||
"""Test updating agent configuration."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Update configuration
|
||||
service.update_agent_config(
|
||||
instructions="Updated instructions", model_config={"model": "gpt-4o", "temperature": 0.7}
|
||||
)
|
||||
|
||||
assert service.agent.instructions == "Updated instructions"
|
||||
assert service.agent.model_config["model"] == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_session_context(mock_openai_agents):
|
||||
"""Test session context management."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent",
|
||||
instructions="Test instructions",
|
||||
api_key="test-key",
|
||||
session_config={"user_id": "test-user"},
|
||||
)
|
||||
|
||||
# Get initial context
|
||||
context = service.get_session_context()
|
||||
assert context["user_id"] == "test-user"
|
||||
|
||||
# Update context
|
||||
service.update_session_context({"session_id": "test-session"})
|
||||
|
||||
updated_context = service.get_session_context()
|
||||
assert updated_context["user_id"] == "test-user"
|
||||
assert updated_context["session_id"] == "test-session"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_add_tools(mock_openai_agents):
|
||||
"""Test adding tools to the agent."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Define a test tool
|
||||
def test_tool():
|
||||
return "test result"
|
||||
|
||||
# Add the tool
|
||||
await service.add_tool(test_tool)
|
||||
|
||||
# Check if tool was added (this depends on the mock implementation)
|
||||
assert hasattr(service.agent, "tools")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_lifecycle(mock_openai_agents):
|
||||
"""Test service lifecycle methods."""
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Test start
|
||||
start_frame = StartFrame()
|
||||
await service.start(start_frame)
|
||||
|
||||
# Test cancel
|
||||
cancel_frame = CancelFrame()
|
||||
await service.cancel(cancel_frame)
|
||||
|
||||
# Test stop
|
||||
end_frame = EndFrame()
|
||||
await service.stop(end_frame)
|
||||
|
||||
|
||||
def test_openai_agent_service_import_error():
|
||||
"""Test that import error is handled gracefully."""
|
||||
# Mock the import to fail
|
||||
with patch.dict("sys.modules", {"agents": None}):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# This should trigger the import error
|
||||
import importlib
|
||||
|
||||
import pipecat.services.openai_agent.agent_service
|
||||
|
||||
importlib.reload(pipecat.services.openai_agent.agent_service)
|
||||
|
||||
assert "Missing module" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -196,10 +196,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal start_received
|
||||
start_received = True
|
||||
|
||||
@task.event_handler("on_pipeline_ended")
|
||||
async def on_pipeline_ended(task, frame: EndFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal end_received
|
||||
end_received = True
|
||||
end_received = isinstance(frame, EndFrame)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -214,10 +214,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@task.event_handler("on_pipeline_stopped")
|
||||
async def on_pipeline_ended(task, frame: StopFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal stop_received
|
||||
stop_received = True
|
||||
stop_received = isinstance(frame, StopFrame)
|
||||
|
||||
await task.queue_frame(StopFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -441,10 +441,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def on_pipeline_started(task: PipelineTask, frame: StartFrame):
|
||||
await task.cancel()
|
||||
|
||||
@task.event_handler("on_pipeline_cancelled")
|
||||
async def on_pipeline_cancelled(task: PipelineTask, frame: CancelFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
cancelled = isinstance(frame, CancelFrame)
|
||||
|
||||
try:
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
261
tests/test_run_inference.py
Normal file
261
tests/test_run_inference.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from anthropic import NOT_GIVEN
|
||||
from openai import NotGiven
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response."""
|
||||
# Create service with mocked client
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
]
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hello! How can I help you today?"
|
||||
service._client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.chat.completions.create.assert_called_once_with(
|
||||
model="gpt-4",
|
||||
messages=test_messages,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_client_exception():
|
||||
"""Test that exceptions from the client are propagated."""
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.chat.completions.create.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Anthropic."""
|
||||
# Create service with mocked client
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Hello! How can I help you today?"
|
||||
service._client.messages.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(
|
||||
mock_context, enable_prompt_caching=False
|
||||
)
|
||||
service._client.messages.create.assert_called_once_with(
|
||||
model="claude-3-sonnet-20240229",
|
||||
messages=test_messages,
|
||||
system=test_system,
|
||||
max_tokens=8192,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_client_exception():
|
||||
"""Test that exceptions from the Anthropic client are propagated."""
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=[], system="Test system", tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.messages.create.side_effect = Exception("Anthropic API Error")
|
||||
|
||||
with pytest.raises(Exception, match="Anthropic API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Google."""
|
||||
# Create service with mocked client
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=test_messages, system_instruction=test_system, tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [MagicMock()]
|
||||
mock_response.candidates[0].content = MagicMock()
|
||||
mock_response.candidates[0].content.parts = [MagicMock()]
|
||||
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_client_exception():
|
||||
"""Test that exceptions from the Google client are propagated."""
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=[], system_instruction="Test system", tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(
|
||||
side_effect=Exception("Google API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Google API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
|
||||
# Create service and patch the session client method
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
|
||||
test_system = [{"text": "You are a helpful assistant"}]
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock the client and response
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
|
||||
}
|
||||
mock_client.converse.return_value = mock_response
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
async def mock_client_cm(*args, **kwargs):
|
||||
return mock_client
|
||||
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
mock_client.converse.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_client_exception():
|
||||
"""Test that exceptions from the AWS Bedrock client are propagated."""
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock AWS client to raise exception
|
||||
mock_client = AsyncMock()
|
||||
mock_client.converse.side_effect = Exception("Bedrock API Error")
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
with pytest.raises(Exception, match="Bedrock API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
Reference in New Issue
Block a user