Compare commits
80 Commits
hush/delay
...
hush/openA
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29d4a56663 | ||
|
|
373a09ecd6 | ||
|
|
07f54c48f3 | ||
|
|
c8a3d65aa4 | ||
|
|
50a2a0dc86 | ||
|
|
0421d97954 | ||
|
|
54c8f336c3 | ||
|
|
b086fbafe6 | ||
|
|
cca90791c4 | ||
|
|
f2a5d408de | ||
|
|
044c6eba46 | ||
|
|
db71089f5e | ||
|
|
f861f5066f | ||
|
|
81cede2c60 | ||
|
|
7603203230 | ||
|
|
8569b61598 | ||
|
|
fe42187dc1 | ||
|
|
999e88c942 | ||
|
|
c04df2f28b | ||
|
|
100ef0ab5c | ||
|
|
42886d7105 | ||
|
|
22cbba002a | ||
|
|
c873798ce5 | ||
|
|
d8cd28bb8b | ||
|
|
c2df6c8aee | ||
|
|
82478be861 | ||
|
|
0f2b7bc01b | ||
|
|
1b2a5df017 | ||
|
|
2f496ac74f | ||
|
|
22633a63b0 | ||
|
|
e5ed0424e4 | ||
|
|
786387722a | ||
|
|
9f82c6b4a4 | ||
|
|
99cfcb1d4e | ||
|
|
d595676436 | ||
|
|
0190812ee8 | ||
|
|
2a24061bbb | ||
|
|
89f7e7d199 | ||
|
|
384814e640 | ||
|
|
ab4364b833 | ||
|
|
fafdadad3c | ||
|
|
05dc2fa916 | ||
|
|
0c30cc6ea6 | ||
|
|
c26d336e34 | ||
|
|
37b6198787 | ||
|
|
3c271da94c | ||
|
|
be28d3f93b | ||
|
|
d2f210e960 | ||
|
|
57add41971 | ||
|
|
74b38b59d6 | ||
|
|
dac58deffc | ||
|
|
aff11f5121 | ||
|
|
a4023d3915 | ||
|
|
d6543d244d | ||
|
|
fafcd79870 | ||
|
|
6a717fbbd1 | ||
|
|
9b3f6927c2 | ||
|
|
0b21f8a6bd | ||
|
|
8249b014f0 | ||
|
|
9d9f10ae0e | ||
|
|
e27b23694d | ||
|
|
66ce5fe6bd | ||
|
|
a9b53dc800 | ||
|
|
818352a300 | ||
|
|
3e9fc7be19 | ||
|
|
a2e76bcad8 | ||
|
|
8e8e42717b | ||
|
|
b31322e38e | ||
|
|
fedb8a201f | ||
|
|
8ccd220a60 | ||
|
|
fe79de8f27 | ||
|
|
176573c342 | ||
|
|
75f9914f49 | ||
|
|
f4d6715e32 | ||
|
|
7366b1aee0 | ||
|
|
4699ee8d86 | ||
|
|
e3597801d4 | ||
|
|
2ee481d541 | ||
|
|
48b3ad8f8f | ||
|
|
8bbdc7c8d1 |
285
AGENTS.md
Normal file
285
AGENTS.md
Normal file
@@ -0,0 +1,285 @@
|
||||
# AGENTS.md
|
||||
|
||||
## Project Overview
|
||||
|
||||
Pipecat is an open-source Python framework for building real-time voice and multimodal conversational AI agents. The codebase is organized around a pipeline architecture where data flows through connected services (STT → LLM → TTS).
|
||||
|
||||
## Development Environment Setup
|
||||
|
||||
### Prerequisites
|
||||
- **Minimum Python Version:** 3.10
|
||||
- **Recommended Python Version:** 3.12
|
||||
- **Package Manager:** uv (recommended) or pip
|
||||
|
||||
### Setup Commands
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/pipecat-ai/pipecat.git
|
||||
cd pipecat
|
||||
|
||||
# Install dependencies with uv (recommended)
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra ultravox
|
||||
|
||||
# Or with pip
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# Install pre-commit hooks
|
||||
uv run pre-commit install
|
||||
|
||||
# Set up environment variables
|
||||
cp env.example .env
|
||||
```
|
||||
|
||||
## Build and Test Commands
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest tests/test_name.py
|
||||
|
||||
# Run tests with coverage
|
||||
uv run pytest --cov=pipecat --cov-report=html
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Format code (required before commits)
|
||||
uv run ruff format
|
||||
|
||||
# Lint code
|
||||
uv run ruff check
|
||||
|
||||
# Type checking
|
||||
uv run mypy src/pipecat
|
||||
|
||||
# Run pre-commit checks manually
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Documentation
|
||||
```bash
|
||||
# Build API documentation
|
||||
cd docs/api
|
||||
./build-docs.sh
|
||||
|
||||
# Build docs manually
|
||||
sphinx-build -b html . _build/html -W --keep-going
|
||||
```
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
### Python Standards
|
||||
- **Formatting:** Strict PEP 8 via Ruff
|
||||
- **Docstrings:** Google-style format
|
||||
- **Type Hints:** Required for all public APIs
|
||||
- **Import Organization:** Automated via Ruff
|
||||
|
||||
### Docstring Conventions
|
||||
- **Classes:** Describe purpose + `__init__` with complete `Args:` section
|
||||
- **Dataclasses:** Use `Parameters:` section, no `__init__` docstring
|
||||
- **Methods:** Include `Args:` and `Returns:` sections
|
||||
- **Properties:** Must have `Returns:` section
|
||||
- **Examples:** Use `Examples:` section with `::` syntax
|
||||
|
||||
### File Organization
|
||||
```
|
||||
src/pipecat/ # Main package
|
||||
├── processors/ # Frame processors
|
||||
├── services/ # AI service integrations
|
||||
├── transports/ # Communication layers
|
||||
├── frames/ # Data frame definitions
|
||||
└── pipeline/ # Pipeline orchestration
|
||||
|
||||
examples/foundational/ # Step-by-step tutorials
|
||||
tests/ # Test suite
|
||||
```
|
||||
|
||||
## Testing Instructions
|
||||
|
||||
### Test Structure
|
||||
- **Unit Tests:** Test individual components in isolation
|
||||
- **Integration Tests:** Test service interactions
|
||||
- **Example Tests:** Validate foundational examples work
|
||||
|
||||
### Adding Tests
|
||||
```bash
|
||||
# Test naming convention
|
||||
test_<component>_<functionality>.py
|
||||
|
||||
# Run specific test pattern
|
||||
uv run pytest -k "test_pipeline"
|
||||
|
||||
# Run with debugging
|
||||
uv run pytest -s -vv tests/test_name.py::test_function
|
||||
```
|
||||
|
||||
### Pre-commit Requirements
|
||||
All commits must pass:
|
||||
- Ruff formatting
|
||||
- Ruff linting
|
||||
- Type checking
|
||||
- Basic test suite
|
||||
|
||||
## Dependency Management
|
||||
|
||||
### Using uv (Recommended)
|
||||
```bash
|
||||
# Add runtime dependency
|
||||
uv add package-name
|
||||
|
||||
# Add optional dependency
|
||||
uv add --optional service package-name
|
||||
|
||||
# Add development dependency
|
||||
uv add --group dev package-name
|
||||
|
||||
# Update lockfile
|
||||
uv lock
|
||||
|
||||
# Sync dependencies
|
||||
uv sync
|
||||
```
|
||||
|
||||
### Important Notes
|
||||
- **Always commit both `pyproject.toml` and `uv.lock` together**
|
||||
- **Never manually edit `uv.lock`** - it's auto-generated
|
||||
- **Use extras for optional service dependencies** (e.g., `[openai]`, `[cartesia]`)
|
||||
|
||||
## Project Structure Guidelines
|
||||
|
||||
### Service Integration
|
||||
When adding new AI services:
|
||||
1. Create service class in `src/pipecat/services/<provider>/`
|
||||
2. Follow existing patterns (e.g., STTService, LLMService)
|
||||
3. Add to appropriate extras in `pyproject.toml`
|
||||
4. Include tests in `tests/`
|
||||
5. Add documentation examples
|
||||
|
||||
### Frame Processing
|
||||
For custom processors:
|
||||
1. Inherit from `FrameProcessor`
|
||||
2. Implement `process_frame()` method. ALWAYS explicitly call `await super().process_frame(frame, direction)` at the top of this method.
|
||||
3. Handle frame direction (FrameDirection.UPSTREAM/DOWNSTREAM)
|
||||
4. Add proper type hints and docstrings
|
||||
|
||||
### Transport Implementation
|
||||
For new transport layers:
|
||||
1. Inherit from `BaseTransport`
|
||||
2. Implement required abstract methods
|
||||
3. Handle connection lifecycle
|
||||
4. Support both input and output streams
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### API Keys
|
||||
- **Never commit API keys** to the repository
|
||||
- **Use environment variables** for all secrets
|
||||
- **Reference `env.example`** for required variables
|
||||
- **Use `.env` files** for local development
|
||||
|
||||
### Input Validation
|
||||
- **Validate all external inputs** (audio, text, API responses)
|
||||
- **Sanitize user data** before processing
|
||||
- **Handle rate limiting** for external services
|
||||
- **Implement proper timeout handling**
|
||||
|
||||
## Performance Guidelines
|
||||
|
||||
### Memory Management
|
||||
- **Clean up resources** in transport disconnection handlers
|
||||
- **Use async context managers** for service connections
|
||||
- **Implement proper frame lifecycle** management
|
||||
|
||||
### Latency Optimization
|
||||
- **Choose appropriate STT services** for latency requirements
|
||||
- **Use streaming TTS** when possible
|
||||
- **Implement connection pooling** for HTTP services
|
||||
- **Consider WebRTC** for real-time applications
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Error Handling
|
||||
```python
|
||||
@transport.event_handler("on_error")
|
||||
async def on_error(transport, error):
|
||||
logger.error(f"Transport error: {error}")
|
||||
|
||||
# Shutdown the pipeline
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
```
|
||||
|
||||
### Service Configuration
|
||||
```python
|
||||
# Use environment variables for configuration
|
||||
service = OpenAILLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY", ""),
|
||||
model="gpt-4o",
|
||||
params={"temperature": 0.7}
|
||||
)
|
||||
```
|
||||
|
||||
### Pipeline Assembly
|
||||
```python
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt_service,
|
||||
context_aggregator.user(),
|
||||
llm_service,
|
||||
tts_service,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
```
|
||||
|
||||
## Commit and PR Guidelines
|
||||
|
||||
### Commit Message Format
|
||||
```
|
||||
<type>(<scope>): <description>
|
||||
|
||||
[optional body]
|
||||
|
||||
[optional footer]
|
||||
```
|
||||
|
||||
Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
|
||||
|
||||
### PR Requirements
|
||||
- **All tests must pass**
|
||||
- **Code must be properly formatted** (Ruff)
|
||||
- **Include appropriate tests** for new functionality
|
||||
- **Update documentation** if needed
|
||||
- **Reference related issues** in description
|
||||
|
||||
### Review Process
|
||||
1. Automated checks must pass
|
||||
2. Manual code review by maintainers
|
||||
3. Documentation review for user-facing changes
|
||||
4. Integration testing for service additions
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
- **Import errors:** Run `uv sync` to ensure dependencies are installed
|
||||
- **Test failures:** Check environment variables in `.env`
|
||||
- **Format errors:** Run `uv run ruff format` before committing
|
||||
- **Type errors:** Ensure all public methods have type hints
|
||||
|
||||
### Development Tips
|
||||
- **Use foundational examples** as starting points for testing
|
||||
- **Check existing services** for integration patterns
|
||||
- **Run tests frequently** during development
|
||||
- **Use IDE integration** for Ruff formatting
|
||||
|
||||
### Getting Help
|
||||
- **Documentation:** [docs.pipecat.ai](https://docs.pipecat.ai)
|
||||
- **Issues:** [GitHub Issues](https://github.com/pipecat-ai/pipecat/issues)
|
||||
93
CHANGELOG.md
93
CHANGELOG.md
@@ -9,11 +9,97 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
|
||||
fired when the pipeline is done running. This can be the result of a
|
||||
`StopFrame`, `CancelFrame` or `EndFrame`.
|
||||
|
||||
```python
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
...
|
||||
```
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
|
||||
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
|
||||
instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
|
||||
when a bot speaks and user input is blocked.
|
||||
|
||||
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
|
||||
`on_client_disconnected` would be triggered when the bot ends the
|
||||
conversation. That is, `on_client_disconnected` should only be triggered when
|
||||
the remote client actually disconnects.
|
||||
|
||||
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
|
||||
was blocked from moving through the Pipeline.
|
||||
|
||||
## [0.0.85] - 2025-09-12
|
||||
|
||||
### Added
|
||||
|
||||
- `AzureSTTService` now pushes interim transcriptions.
|
||||
|
||||
- Added `voice_cloning_key` to `GoogleTTSService` to support custom cloned
|
||||
voices.
|
||||
|
||||
- Added `speaking_rate` to `GoogleTTSService.InputParams` to control the
|
||||
speaking rate.
|
||||
|
||||
- Added a `speed` arg to `OpenAITTSService` to control the speed of the voice
|
||||
response.
|
||||
|
||||
- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this
|
||||
method to programatically interrupt the bot from any part of the
|
||||
pipeline. This guarantees that all the processors in the pipeline are
|
||||
interrupted in order (from upstream to downstream). Internally, this works by
|
||||
first pushing an `InterruptionTaskFrame` upstream until it reaches the
|
||||
pipeline task. The pipeline task then generates an `InterruptionFrame`, which
|
||||
flows downstream through all processors. Once the `InterruptionFrame` has
|
||||
reaches the processor waiting for the interruption, the function returns and
|
||||
execution continues after the call. Think of it as sending an upstream request
|
||||
for interruption and waiting until the acknowledgment flows back downstream.
|
||||
|
||||
- Added new base `TaskFrame` (which is a system frame). This is the base class
|
||||
for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant
|
||||
to be pushed upstream to reach the pipeline task.
|
||||
|
||||
- Expanded support for universal `LLMContext` to the AWS Bedrock LLM service.
|
||||
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
|
||||
a pre-requisite for using `LLMSwitcher` to switch between LLMs at runtime.
|
||||
|
||||
- Added new fields to the development runner's `parse_telephony_websocket`
|
||||
method in support of providing dynamic data to a bot.
|
||||
|
||||
- Twilio: Added a new `body` parameter, which parses the websocket message
|
||||
for `customParameters`. Provide data via the `Parameter` nouns in your
|
||||
TwiML to use this feature.
|
||||
- Telnyx & Exotel: Both providers make the `to` and `from` phone numbers
|
||||
available in the websocket messages. You can now access these numbers as
|
||||
`call_data["to"]` and `call_data["from"]`.
|
||||
|
||||
Note: Each telephony provider offers different features. Refer to the
|
||||
corresponding example in `pipecat-examples` to see how to pass custom data
|
||||
to your bot.
|
||||
|
||||
- Added `body` to the `WebsocketRunnerArguments` as an optional parameter.
|
||||
Custom `body` information can be passed from the server into the bot file via
|
||||
the `bot()` method using this new parameter.
|
||||
|
||||
- Added video streaming support to `LiveKitTransport`.
|
||||
|
||||
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide
|
||||
access to OpenAI Realtime.
|
||||
|
||||
### Changed
|
||||
|
||||
- `pipeline.tests.utils.run_test()` now allows passing `PipelineParams` instead
|
||||
of individual parameters.
|
||||
|
||||
### Removed
|
||||
|
||||
- Remove `VisionImageRawFrame` in favor of context frames (`LLMContextFrame` or
|
||||
@@ -21,6 +107,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `BotInterruptionFrame` is now deprecated, use `InterruptionTaskFrame` instead.
|
||||
|
||||
- `StartInterruptionFrame` is now deprected, use `InterruptionFrame` instead.
|
||||
|
||||
- Deprecate `VisionImageFrameAggregator` because `VisionImageRawFrame` has been
|
||||
removed. See the `12*` examples for the new recommended replacement pattern.
|
||||
|
||||
@@ -33,6 +123,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that caused incorrect detection of when
|
||||
the bot stopped talking while using an audio mixer.
|
||||
|
||||
- Fixed a `LiveKitTransport` issue where RTVI messages were not properly
|
||||
encoded.
|
||||
|
||||
|
||||
23
README.md
23
README.md
@@ -153,7 +153,11 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
2. Install development and testing dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --group dev --all-extras --no-extra gstreamer --no-extra krisp --no-extra local
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra ultravox # (ultravox not fully supported on macOS)
|
||||
```
|
||||
|
||||
3. Install the git pre-commit hooks:
|
||||
@@ -162,23 +166,6 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Python 3.13+ Compatibility
|
||||
|
||||
Some features require PyTorch, which doesn't yet support Python 3.13+. Install using:
|
||||
|
||||
```bash
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra local-smart-turn \
|
||||
--no-extra mlx-whisper \
|
||||
--no-extra moondream \
|
||||
--no-extra ultravox
|
||||
```
|
||||
|
||||
> **Tip:** For full compatibility, use Python 3.12: `uv python pin 3.12`
|
||||
|
||||
> **Note**: Some extras (local, gstreamer) require system dependencies. See documentation if you encounter build errors.
|
||||
|
||||
### Running tests
|
||||
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -115,7 +115,7 @@ async def main():
|
||||
|
||||
await task.queue_frames(
|
||||
[
|
||||
BotInterruptionFrame(),
|
||||
InterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(
|
||||
user_id=participant_id,
|
||||
|
||||
@@ -36,7 +36,6 @@ load_dotenv(override=True)
|
||||
audiobuffer = AudioBufferProcessor(
|
||||
num_channels=2, # 1 for mono, 2 for stereo (user left, bot right)
|
||||
enable_turn_audio=False, # Enable per-turn audio recording
|
||||
user_continuous_stream=True, # User has continuous audio stream
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
@stt.event_handler("on_speech_started")
|
||||
async def on_speech_started(stt, *args, **kwargs):
|
||||
await task.queue_frames([StartInterruptionFrame(), UserStartedSpeakingFrame()])
|
||||
await task.queue_frames([InterruptionFrame(), UserStartedSpeakingFrame()])
|
||||
|
||||
@stt.event_handler("on_utterance_end")
|
||||
async def on_utterance_end(stt, *args, **kwargs):
|
||||
|
||||
@@ -16,10 +16,10 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -181,9 +181,7 @@ class TranscriptionContextFixup(FrameProcessor):
|
||||
|
||||
if isinstance(frame, MagicDemoTranscriptionFrame):
|
||||
self._transcript = frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(
|
||||
frame, StartInterruptionFrame
|
||||
):
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, InterruptionFrame):
|
||||
self.swap_user_audio()
|
||||
self.add_transcript_back_to_inference_output()
|
||||
self._transcript = ""
|
||||
|
||||
@@ -13,6 +13,7 @@ from loguru import logger
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -21,10 +22,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
@@ -73,14 +71,14 @@ class UserImageProcessor(FrameProcessor):
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
# Note: AWS Bedrock does not yet support the universal LLMContext
|
||||
context = OpenAILLMContext()
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = OpenAILLMContextFrame(context)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -121,6 +119,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
aws = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
# Note: usually, prefer providing latency="optimized" param.
|
||||
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
|
||||
# which we need for image input.
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
get_transport_client_id,
|
||||
maybe_capture_participant_camera,
|
||||
)
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# Global variable to store the client ID
|
||||
client_id = ""
|
||||
|
||||
|
||||
async def get_weather(params: FunctionCallParams):
|
||||
location = params.arguments["location"]
|
||||
await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(params: FunctionCallParams):
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={client_id}, question={question}")
|
||||
|
||||
# Request the image frame
|
||||
await params.llm.request_image_frame(
|
||||
user_id=client_id,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
text_content=question,
|
||||
)
|
||||
|
||||
# Wait a short time for the frame to be processed
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a result to complete the function call
|
||||
await params.result_callback(
|
||||
f"I've captured an image from your camera and I'm analyzing what you asked about: {question}"
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
# Note: usually, prefer providing latency="optimized" param.
|
||||
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
|
||||
# which we need for image input.
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
get_image_function = FunctionSchema(
|
||||
name="get_image",
|
||||
description="Get an image from the video stream.",
|
||||
properties={
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question that the user is asking about the image.",
|
||||
}
|
||||
},
|
||||
required=["question"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, get_image_function])
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
|
||||
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Start the conversation by introducing yourself."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
await maybe_capture_participant_camera(transport, client)
|
||||
|
||||
global client_id
|
||||
client_id = get_transport_client_id(transport, client)
|
||||
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -22,7 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioNoiseReduction,
|
||||
@@ -31,7 +31,6 @@ from pipecat.services.openai_realtime_beta import (
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime_beta.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -114,18 +113,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
)
|
||||
),
|
||||
output_modalities=["text"],
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
modalities=["text"],
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
input_audio_noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -144,7 +144,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -232,7 +232,7 @@ class TurnDetectionLLM(Pipeline):
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, InterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -347,7 +347,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -426,7 +426,7 @@ class TurnDetectionLLM(Pipeline):
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, InterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
@@ -20,10 +20,10 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -570,7 +570,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class CustomObserver(BaseObserver):
|
||||
"""Observer to log interruptions and bot speaking events to the console.
|
||||
|
||||
Logs all frame instances of:
|
||||
- StartInterruptionFrame
|
||||
- InterruptionFrame
|
||||
- BotStartedSpeakingFrame
|
||||
- BotStoppedSpeakingFrame
|
||||
|
||||
@@ -69,7 +69,7 @@ class CustomObserver(BaseObserver):
|
||||
# Create direction arrow
|
||||
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport):
|
||||
if isinstance(frame, InterruptionFrame) and isinstance(src, BaseOutputTransport):
|
||||
logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
|
||||
@@ -11,7 +11,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
@@ -31,20 +31,7 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
load_dotenv(override=True)
|
||||
|
||||
# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
|
||||
# to the path where the smart-turn repo is cloned.
|
||||
#
|
||||
# Example setup:
|
||||
#
|
||||
# # Git LFS (Large File Storage)
|
||||
# brew install git-lfs
|
||||
# # Hugging Face uses LFS to store large model files, including .mlpackage
|
||||
# git lfs install
|
||||
# # Clone the repo with the smart_turn_classifier.mlpackage
|
||||
# git clone https://huggingface.co/pipecat-ai/smart-turn-v2
|
||||
#
|
||||
# Then set the env variable:
|
||||
# export LOCAL_SMART_TURN_MODEL_PATH=./smart-turn
|
||||
# or add it to your .env file
|
||||
# to the Smart Turn v3 ONNX model file.
|
||||
smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -55,7 +42,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
@@ -63,7 +50,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
@@ -71,7 +58,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
|
||||
205
examples/foundational/45-openai-agent-basic.py
Normal file
205
examples/foundational/45-openai-agent-basic.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Basic OpenAI Agent service example.
|
||||
|
||||
This example demonstrates how to use the OpenAI Agents SDK within a Pipecat
|
||||
pipeline to create an interactive agent with tool calling capabilities.
|
||||
|
||||
Requirements:
|
||||
- OpenAI API key
|
||||
- OpenAI Agents SDK: pip install openai-agents
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, List
|
||||
|
||||
# Import agents SDK for tools and agent creation
|
||||
from agents import Agent, function_tool
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.frames.frames import LLMRunFrame, TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Transport configuration
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the current weather for a location.
|
||||
|
||||
Args:
|
||||
location: The location to get weather for
|
||||
|
||||
Returns:
|
||||
A weather description string
|
||||
"""
|
||||
# Mock weather data - in real usage, integrate with weather API
|
||||
weather_data = {
|
||||
"San Francisco": "Foggy, 65°F",
|
||||
"New York": "Sunny, 72°F",
|
||||
"London": "Rainy, 59°F",
|
||||
"Tokyo": "Partly cloudy, 68°F",
|
||||
}
|
||||
return weather_data.get(location, f"Weather data not available for {location}")
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact.
|
||||
|
||||
Returns:
|
||||
A random fact string
|
||||
"""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
|
||||
def get_random_fact_tool():
|
||||
"""Example tool function for random facts."""
|
||||
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact.
|
||||
|
||||
Returns:
|
||||
A random fact string.
|
||||
"""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"A group of flamingos is called a 'flamboyance'.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
return get_random_fact
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting OpenAI Agent bot")
|
||||
|
||||
# Set up STT for speech recognition
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
model="nova-2",
|
||||
)
|
||||
|
||||
# Set up TTS for voice output
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY", ""),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Create tools for the agent
|
||||
tools: list[Any] = [
|
||||
get_weather,
|
||||
get_random_fact,
|
||||
]
|
||||
|
||||
# Create the agent with tools
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions="""You are a helpful assistant with access to weather information and random facts.
|
||||
You can:
|
||||
- Check weather for any location using the get_weather tool
|
||||
- Share interesting facts using the get_random_fact tool
|
||||
- Have natural conversations
|
||||
|
||||
Be friendly, informative, and engaging in your responses.""",
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Initialize the OpenAI Agent service with the pre-configured agent
|
||||
agent_service = OpenAIAgentService(
|
||||
agent=agent,
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Set up conversation context with initial system message
|
||||
messages: List[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant with access to weather information and random facts. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = agent_service.create_context_aggregator(context)
|
||||
|
||||
# Create the processing pipeline with context aggregators
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # Speech to text
|
||||
context_aggregator.user(), # User responses
|
||||
agent_service, # OpenAI Agent processing
|
||||
tts, # Text to speech
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Send an initial greeting when client connects
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected, sending greeting")
|
||||
# Kick off the conversation by adding system message and running LLM
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
276
examples/foundational/46-openai-agent-handoffs.py
Normal file
276
examples/foundational/46-openai-agent-handoffs.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Advanced OpenAI Agent service example with handoffs.
|
||||
|
||||
This example demonstrates how to use multiple agents with handoffs in the
|
||||
OpenAI Agents SDK within a Pipecat pipeline, showcasing agent orchestration
|
||||
and specialization.
|
||||
|
||||
Requirements:
|
||||
- OpenAI API key
|
||||
- OpenAI Agents SDK: pip install openai-agents
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.frames.frames import LLMRunFrame, TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Transport configuration
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_out_enabled=True, audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
def create_weather_tools():
|
||||
"""Create weather-related tools."""
|
||||
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get current weather for a location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"]
|
||||
temp = random.randint(-10, 35)
|
||||
condition = random.choice(conditions)
|
||||
return f"The weather in {location} is {condition} with a temperature of {temp}°C."
|
||||
|
||||
def get_forecast(location: str, days: int = 3) -> str:
|
||||
"""Get weather forecast for multiple days."""
|
||||
forecast = []
|
||||
for i in range(days):
|
||||
conditions = ["sunny", "cloudy", "rainy", "snowy"]
|
||||
temp = random.randint(-5, 30)
|
||||
condition = random.choice(conditions)
|
||||
day = "today" if i == 0 else f"in {i} day{'s' if i > 1 else ''}"
|
||||
forecast.append(f"{day.capitalize()}: {condition}, {temp}°C")
|
||||
return f"Weather forecast for {location}:\n" + "\n".join(forecast)
|
||||
|
||||
return [get_weather, get_forecast]
|
||||
|
||||
|
||||
def create_trivia_tools():
|
||||
"""Create trivia and fact tools."""
|
||||
|
||||
def get_random_fact() -> str:
|
||||
"""Get a random interesting fact."""
|
||||
facts = [
|
||||
"Honey never spoils. Archaeologists have found edible honey in ancient Egyptian tombs.",
|
||||
"A group of flamingos is called a 'flamboyance'.",
|
||||
"Octopuses have three hearts and blue blood.",
|
||||
"The Great Wall of China isn't visible from space with the naked eye.",
|
||||
"Bananas are berries, but strawberries aren't.",
|
||||
"Wombat poop is cube-shaped.",
|
||||
"A shrimp's heart is in its head.",
|
||||
"It's impossible to hum while holding your nose.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
def get_science_fact() -> str:
|
||||
"""Get a random science fact."""
|
||||
facts = [
|
||||
"The speed of light in a vacuum is approximately 299,792,458 meters per second.",
|
||||
"DNA stands for Deoxyribonucleic Acid.",
|
||||
"The human brain uses about 20% of the body's total energy.",
|
||||
"There are more possible games of chess than atoms in the observable universe.",
|
||||
"A single bolt of lightning contains enough energy to toast 100,000 slices of bread.",
|
||||
]
|
||||
return random.choice(facts)
|
||||
|
||||
return [get_random_fact, get_science_fact]
|
||||
|
||||
|
||||
def create_math_tools():
|
||||
"""Create math calculation tools."""
|
||||
|
||||
def calculate(expression: str) -> str:
|
||||
"""Safely calculate a mathematical expression."""
|
||||
try:
|
||||
# Only allow basic math operations for safety
|
||||
allowed_chars = set("0123456789+-*/.() ")
|
||||
if not all(c in allowed_chars for c in expression):
|
||||
return "Sorry, I can only calculate basic math expressions with +, -, *, /, and parentheses."
|
||||
|
||||
result = eval(expression)
|
||||
return f"{expression} = {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating '{expression}': {str(e)}"
|
||||
|
||||
def generate_math_problem() -> str:
|
||||
"""Generate a random math problem."""
|
||||
operations = ["+", "-", "*"]
|
||||
a = random.randint(1, 20)
|
||||
b = random.randint(1, 20)
|
||||
op = random.choice(operations)
|
||||
|
||||
if op == "+":
|
||||
answer = a + b
|
||||
elif op == "-":
|
||||
answer = a - b
|
||||
else: # multiplication
|
||||
answer = a * b
|
||||
|
||||
return f"Here's a math problem for you: {a} {op} {b} = ?"
|
||||
|
||||
return [calculate, generate_math_problem]
|
||||
|
||||
|
||||
async def create_specialist_agents():
|
||||
"""Create specialized agents for different domains."""
|
||||
|
||||
# Weather specialist agent
|
||||
weather_agent = OpenAIAgentService(
|
||||
name="Weather Specialist",
|
||||
instructions="""You are a weather specialist. You provide detailed weather information,
|
||||
forecasts, and weather-related advice. Use your tools to get accurate weather data.
|
||||
Be informative and helpful about weather conditions and what they might mean for
|
||||
outdoor activities.""",
|
||||
tools=create_weather_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Trivia specialist agent
|
||||
trivia_agent = OpenAIAgentService(
|
||||
name="Trivia Master",
|
||||
instructions="""You are a trivia and facts specialist. You love sharing interesting
|
||||
facts, trivia, and educational content. Use your tools to provide fascinating
|
||||
information and engage users with fun facts. Make learning enjoyable!""",
|
||||
tools=create_trivia_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Math specialist agent
|
||||
math_agent = OpenAIAgentService(
|
||||
name="Math Helper",
|
||||
instructions="""You are a mathematics specialist. You help with calculations,
|
||||
math problems, and mathematical concepts. Use your tools to solve problems
|
||||
and generate practice questions. Make math accessible and fun!""",
|
||||
tools=create_math_tools(),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return weather_agent, trivia_agent, math_agent
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting OpenAI Agent bot with handoffs")
|
||||
|
||||
# Set up STT for speech recognition
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
model="nova-2",
|
||||
)
|
||||
|
||||
# Set up TTS for voice output
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY", ""),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Create specialist agents
|
||||
weather_agent, trivia_agent, math_agent = await create_specialist_agents()
|
||||
|
||||
# Create the main triage agent that can hand off to specialists
|
||||
triage_agent = OpenAIAgentService(
|
||||
name="Assistant Coordinator",
|
||||
instructions="""You are a helpful assistant coordinator. Your role is to understand
|
||||
what the user needs and direct them to the right specialist:
|
||||
|
||||
- For weather questions, forecasts, or outdoor activity planning -> Weather Specialist
|
||||
- For interesting facts, trivia, or educational content -> Trivia Master
|
||||
- For calculations, math problems, or mathematical help -> Math Helper
|
||||
|
||||
If the request doesn't clearly fit a specialist, you can handle general conversation
|
||||
yourself. Always be friendly and explain when you're connecting them to a specialist.""",
|
||||
handoffs=[weather_agent.agent, trivia_agent.agent, math_agent.agent], # type: ignore
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Set up conversation context with initial system message
|
||||
messages: List[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant coordinator with access to weather information, trivia, and math tools. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = triage_agent.create_context_aggregator(context)
|
||||
|
||||
# Create the processing pipeline with context aggregators
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # Speech to text
|
||||
context_aggregator.user(), # User responses
|
||||
triage_agent, # OpenAI Agent processing
|
||||
tts, # Text to speech
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Send an initial greeting when client connects
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected, sending greeting")
|
||||
# Kick off the conversation by adding system message and running LLM
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please introduce yourself to the user as an AI assistant coordinator who works with specialists for weather, trivia, and math topics.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<=1.99.1",
|
||||
"openai>=1.74.0,<2.0.0",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
@@ -74,7 +74,7 @@ langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-ope
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mcp = [ "mcp[cli]>=1.11.0,<2.0.0" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
@@ -83,7 +83,8 @@ nim = []
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openai-agent = [ "openai-agents~=0.3.0" ]
|
||||
# openpipe = [ "openpipe~=4.50.0" ] # Temporarily disabled due to openai version conflict
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
@@ -95,8 +96,9 @@ sambanova = []
|
||||
sarvam = [ "websockets>=13.1,<15.0" ]
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime>=1.20.1, <2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime~=1.20.1" ]
|
||||
silero = [ "onnxruntime>=1.20.1, <2" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "websockets>=13.1,<15.0" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
@@ -154,6 +156,7 @@ where = ["src"]
|
||||
"src/pipecat/audio/dtmf/dtmf-star.wav",
|
||||
]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
|
||||
@@ -135,6 +135,25 @@ TESTS_14 = [
|
||||
("14r-function-calling-aws.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14v-function-calling-openai.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14w-function-calling-mistral.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14x-function-calling-universal-context.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
(
|
||||
"14y-function-calling-google-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
(
|
||||
"14z-function-calling-anthropic-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
(
|
||||
"14aa-function-calling-aws-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
# Currently not working.
|
||||
# ("14c-function-calling-together.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
# ("14l-function-calling-deepseek.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
@@ -148,6 +167,7 @@ TESTS_15 = [
|
||||
TESTS_19 = [
|
||||
("19-openai-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19a-azure-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-beta-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
]
|
||||
|
||||
|
||||
@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
|
||||
|
||||
Returns:
|
||||
The identifier string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
|
||||
|
||||
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
"""Get messages from the LLM context, including standard and LLM-specific messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages including standard and LLM-specific messages.
|
||||
"""
|
||||
return context.get_messages(self.id_for_llm_specific_messages)
|
||||
|
||||
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
|
||||
from anthropic import NOT_GIVEN, NotGiven
|
||||
from anthropic.types.message_param import MessageParam
|
||||
@@ -28,10 +28,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
|
||||
|
||||
class AnthropicLLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking Anthropic's LLM API.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
"""
|
||||
"""Context-based parameters for invoking Anthropic's LLM API."""
|
||||
|
||||
system: str | NotGiven
|
||||
messages: List[MessageParam]
|
||||
@@ -45,13 +42,16 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
|
||||
return "anthropic"
|
||||
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
enable_prompt_caching: Whether prompt caching should be enabled.
|
||||
@@ -59,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
@@ -76,8 +76,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
@@ -85,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -99,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
specific function-calling format, enabling tool use with Nova Sonic models.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -6,21 +6,33 @@
|
||||
|
||||
"""AWS Bedrock LLM adapter for Pipecat."""
|
||||
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class AWSBedrockLLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking AWS Bedrock's LLM API.
|
||||
"""Context-based parameters for invoking AWS Bedrock's LLM API."""
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
"""
|
||||
|
||||
pass
|
||||
system: Optional[List[dict[str, Any]]] # [{"text": "system message"}]
|
||||
messages: List[dict[str, Any]]
|
||||
tools: List[dict[str, Any]]
|
||||
tool_choice: LLMContextToolChoice
|
||||
|
||||
|
||||
class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
@@ -30,33 +42,244 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
into AWS Bedrock's expected tool format for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
|
||||
return "aws"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": messages.messages,
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools) or [],
|
||||
# To avoid refactoring in AWSBedrockLLMService, we just pass through tool_choice.
|
||||
# Eventually (when we don't have to maintain the non-LLMContext code path) we should do
|
||||
# the conversion to Bedrock's expected format here rather than in AWSBedrockLLMService.
|
||||
"tool_choice": context.tool_choice,
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about AWS Bedrock.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about AWS Bedrock.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
for message in messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("image"):
|
||||
item["image"]["source"]["bytes"] = "..."
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
messages: List[dict[str, Any]]
|
||||
system: Optional[str]
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, universal_context_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
system = None
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our messages list
|
||||
if messages and messages[0]["role"] == "system":
|
||||
system = messages[0]["content"]
|
||||
messages.pop(0)
|
||||
|
||||
# Convert any subsequent "system"-role messages to "user"-role
|
||||
# messages, as AWS Bedrock doesn't support system input messages.
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
message["role"] = "user"
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(messages) - 1:
|
||||
current_message = messages[i]
|
||||
next_message = messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system=system)
|
||||
|
||||
def _from_universal_context_message(self, message: LLMContextMessage) -> dict[str, Any]:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return copy.deepcopy(message.message)
|
||||
return self._from_standard_message(message)
|
||||
|
||||
def _from_standard_message(self, message: LLMStandardMessage) -> dict[str, Any]:
|
||||
"""Convert standard format message to AWS Bedrock format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Message in AWS Bedrock format.
|
||||
|
||||
Examples:
|
||||
Standard format input::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
AWS Bedrock format output::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
message = copy.deepcopy(message)
|
||||
if message["role"] == "tool":
|
||||
# Try to parse the content as JSON if it looks like JSON
|
||||
try:
|
||||
if message["content"].strip().startswith("{") and message[
|
||||
"content"
|
||||
].strip().endswith("}"):
|
||||
content_json = json.loads(message["content"])
|
||||
tool_result_content = [{"json": content_json}]
|
||||
else:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
except:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message["tool_call_id"],
|
||||
"content": tool_result_content,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"toolUse": {
|
||||
"toolUseId": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
|
||||
# Handle text content
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
return {"role": message["role"], "content": [{"text": "(empty)"}]}
|
||||
else:
|
||||
return {"role": message["role"], "content": [{"text": content}]}
|
||||
elif isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item.get("type", "") == "text":
|
||||
text_content = item["text"] if item["text"] != "" else "(empty)"
|
||||
new_content.append({"text": text_content})
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
new_item = {
|
||||
"image": {
|
||||
"format": "jpeg",
|
||||
"source": {
|
||||
"bytes": base64.b64decode(item["image_url"]["url"].split(",")[1])
|
||||
},
|
||||
}
|
||||
}
|
||||
new_content.append(new_item)
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text
|
||||
image_indices = [i for i, item in enumerate(new_content) if "image" in item]
|
||||
text_indices = [i for i, item in enumerate(new_content) if "text" in item]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = new_content.pop(img_idx)
|
||||
new_content.insert(first_txt_idx, image_item)
|
||||
return {"role": message["role"], "content": new_content}
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("google")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Google-formatted messages converted from universal context."""
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
|
||||
return "openai"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -57,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self._get_messages(context)),
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
@@ -91,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self._get_messages(context):
|
||||
for message in self.get_messages(context):
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
@@ -104,14 +110,18 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("openai")
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
# Just a pass-through: messages are already the right type
|
||||
return messages
|
||||
result = []
|
||||
for message in messages:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Extract the actual message content from LLMSpecificMessage
|
||||
result.append(message.message)
|
||||
else:
|
||||
# Standard message, pass through unchanged
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
def _from_standard_tool_choice(
|
||||
self, tool_choice: LLMContextToolChoice | NotGiven
|
||||
|
||||
@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
OpenAI's Realtime API for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
|
||||
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
0
src/pipecat/audio/turn/smart_turn/data/__init__.py
Normal file
0
src/pipecat/audio/turn/smart_turn/data/__init__.py
Normal file
BIN
src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx
Normal file
BIN
src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx
Normal file
Binary file not shown.
124
src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py
Normal file
124
src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.
|
||||
|
||||
This module provides a smart turn analyzer that uses an ONNX model for
|
||||
local end-of-turn detection without requiring network connectivity.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
from transformers import WhisperFeatureExtractor
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use LocalSmartTurnAnalyzerV3, you need to `pip install pipecat-ai[local-smart-turn-v3]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""Local turn analyzer using the smart-turn-v3 ONNX model.
|
||||
|
||||
Provides end-of-turn detection using locally-stored ONNX model,
|
||||
enabling offline operation without network dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self, *, smart_turn_model_path: Optional[str] = None, **kwargs):
|
||||
"""Initialize the local ONNX smart-turn-v3 analyzer.
|
||||
|
||||
Args:
|
||||
smart_turn_model_path: Path to the ONNX model file. If this is not
|
||||
set, the bundled smart-turn-v3.0 model will be used.
|
||||
**kwargs: Additional arguments passed to BaseSmartTurn.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
logger.debug("Loading Local Smart Turn v3 model...")
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.0.onnx"
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
import importlib_resources as impresources
|
||||
|
||||
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
|
||||
except BaseException:
|
||||
from importlib import resources as impresources
|
||||
|
||||
try:
|
||||
with impresources.path(package_path, model_name) as f:
|
||||
smart_turn_model_path = f
|
||||
except BaseException:
|
||||
smart_turn_model_path = str(
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
)
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
so.inter_op_num_threads = 1
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3")
|
||||
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000):
|
||||
"""Truncate audio to last n seconds or pad with zeros to meet n seconds."""
|
||||
max_samples = n_seconds * sample_rate
|
||||
if len(audio_array) > max_samples:
|
||||
return audio_array[-max_samples:]
|
||||
elif len(audio_array) < max_samples:
|
||||
# Pad with zeros at the beginning
|
||||
padding = max_samples - len(audio_array)
|
||||
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
||||
return audio_array
|
||||
|
||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||
|
||||
# Process audio using Whisper's feature extractor
|
||||
inputs = self._feature_extractor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=8 * 16000,
|
||||
truncation=True,
|
||||
do_normalize=True,
|
||||
)
|
||||
|
||||
# Convert to numpy and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32)
|
||||
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
|
||||
|
||||
# Run ONNX inference
|
||||
outputs = self._session.run(None, {"input_features": input_features})
|
||||
|
||||
# Extract probability (ONNX model returns sigmoid probabilities)
|
||||
probability = outputs[0][0].item()
|
||||
|
||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||
prediction = 1 if probability > 0.5 else 0
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"probability": probability,
|
||||
}
|
||||
@@ -21,7 +21,6 @@ from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -360,7 +359,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
await self._voicemail_notifier.notify() # Clear buffered TTS frames
|
||||
|
||||
# Interrupt the current pipeline to stop any ongoing processing
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Set the voicemail event to trigger the voicemail handler
|
||||
self._voicemail_event.clear()
|
||||
|
||||
@@ -788,43 +788,6 @@ class FatalErrorFrame(ErrorFrame):
|
||||
fatal: bool = field(default=True, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTaskFrame(SystemFrame):
|
||||
"""Frame to request graceful pipeline task closure.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
closed nicely (flushing all the queued frames) by pushing an EndFrame
|
||||
downstream. This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelTaskFrame(SystemFrame):
|
||||
"""Frame to request immediate pipeline task cancellation.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
stopped immediately by pushing a CancelFrame downstream. This frame
|
||||
should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopTaskFrame(SystemFrame):
|
||||
"""Frame to request pipeline task stop while keeping processors running.
|
||||
|
||||
This is used to notify the pipeline task that it should be stopped as
|
||||
soon as possible (flushing all the queued frames) but that the pipeline
|
||||
processors should be kept in a running state. This frame should be pushed
|
||||
upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameProcessorPauseUrgentFrame(SystemFrame):
|
||||
"""Frame to pause frame processing immediately.
|
||||
@@ -857,7 +820,7 @@ class FrameProcessorResumeUrgentFrame(SystemFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartInterruptionFrame(SystemFrame):
|
||||
class InterruptionFrame(SystemFrame):
|
||||
"""Frame indicating user started speaking (interruption detected).
|
||||
|
||||
Emitted by the BaseInputTransport to indicate that a user has started
|
||||
@@ -869,6 +832,34 @@ class StartInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartInterruptionFrame(InterruptionFrame):
|
||||
"""Frame indicating user started speaking (interruption detected).
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
Instead, use `InterruptionFrame`.
|
||||
|
||||
Emitted by the BaseInputTransport to indicate that a user has started
|
||||
speaking (i.e. is interrupting). This is similar to
|
||||
UserStartedSpeakingFrame except that it should be pushed concurrently
|
||||
with other frames (so the order is not guaranteed).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"StartInterruptionFrame is deprecated and will be removed in a future version. "
|
||||
"Instead, use InterruptionFrame.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame indicating user has started speaking.
|
||||
@@ -944,20 +935,6 @@ class VADUserStoppedSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstreams. It results in the BaseInputTransport
|
||||
starting an interruption by pushing a StartInterruptionFrame downstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame indicating the bot started speaking.
|
||||
@@ -1289,6 +1266,103 @@ class SpeechControlParamsFrame(SystemFrame):
|
||||
turn_params: Optional[SmartTurnParams] = None
|
||||
|
||||
|
||||
#
|
||||
# Task frames
|
||||
#
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskFrame(SystemFrame):
|
||||
"""Base frame for task frames.
|
||||
|
||||
This is a base class for frames that are meant to be sent and handled
|
||||
upstream by the pipeline task. This might result in a corresponding frame
|
||||
sent downstream (e.g. `InterruptionTaskFrame` / `InterruptionFrame` or
|
||||
`EndTaskFrame` / `EndFrame`).
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTaskFrame(TaskFrame):
|
||||
"""Frame to request graceful pipeline task closure.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
closed nicely (flushing all the queued frames) by pushing an EndFrame
|
||||
downstream. This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelTaskFrame(TaskFrame):
|
||||
"""Frame to request immediate pipeline task cancellation.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
stopped immediately by pushing a CancelFrame downstream. This frame
|
||||
should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopTaskFrame(TaskFrame):
|
||||
"""Frame to request pipeline task stop while keeping processors running.
|
||||
|
||||
This is used to notify the pipeline task that it should be stopped as
|
||||
soon as possible (flushing all the queued frames) but that the pipeline
|
||||
processors should be kept in a running state. This frame should be pushed
|
||||
upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterruptionTaskFrame(TaskFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(InterruptionTaskFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
Instead, use `InterruptionTaskFrame`.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"BotInterruptionFrame is deprecated and will be removed in a future version. "
|
||||
"Instead, use InterruptionTaskFrame.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Control frames
|
||||
#
|
||||
|
||||
@@ -54,7 +54,7 @@ class DebugLogObserver(BaseObserver):
|
||||
|
||||
Log frames with specific source/destination filters::
|
||||
|
||||
from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
|
||||
from pipecat.frames.frames import InterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.services.stt_service import STTService
|
||||
@@ -62,8 +62,8 @@ class DebugLogObserver(BaseObserver):
|
||||
observers=[
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
# Only log StartInterruptionFrame when source is BaseOutputTransport
|
||||
StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
# Only log InterruptionFrame when source is BaseOutputTransport
|
||||
InterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
# Only log UserStartedSpeakingFrame when destination is STTService
|
||||
UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION),
|
||||
# Log LLMTextFrame regardless of source or destination type
|
||||
|
||||
@@ -32,6 +32,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
StopFrame,
|
||||
@@ -113,9 +115,28 @@ class PipelineTask(BasePipelineTask):
|
||||
- on_frame_reached_downstream: Called when downstream frames reach the sink
|
||||
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
|
||||
- on_pipeline_started: Called when pipeline starts with StartFrame
|
||||
- on_pipeline_stopped: Called when pipeline stops with StopFrame
|
||||
- on_pipeline_ended: Called when pipeline ends with EndFrame
|
||||
- on_pipeline_cancelled: Called when pipeline is cancelled
|
||||
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
|
||||
This includes:
|
||||
- StopFrame: pipeline was stopped (processors keep connections open)
|
||||
- EndFrame: pipeline ended normally
|
||||
- CancelFrame: pipeline was cancelled
|
||||
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
|
||||
the frame if they need to handle specific cases.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -126,6 +147,10 @@ class PipelineTask(BasePipelineTask):
|
||||
@task.event_handler("on_idle_timeout")
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -262,6 +287,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._register_event_handler("on_pipeline_stopped")
|
||||
self._register_event_handler("on_pipeline_ended")
|
||||
self._register_event_handler("on_pipeline_cancelled")
|
||||
self._register_event_handler("on_pipeline_finished")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
@@ -290,6 +316,27 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._turn_trace_observer
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to handle.
|
||||
|
||||
Returns:
|
||||
The decorator function that registers the handler.
|
||||
"""
|
||||
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return super().event_handler(event_name)
|
||||
|
||||
def add_observer(self, observer: BaseObserver):
|
||||
"""Add an observer to monitor pipeline execution.
|
||||
|
||||
@@ -532,6 +579,7 @@ class PipelineTask(BasePipelineTask):
|
||||
)
|
||||
finally:
|
||||
await self._call_event_handler("on_pipeline_cancelled", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
|
||||
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
|
||||
|
||||
@@ -627,13 +675,23 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
logger.debug(f"{self}: received end task frame {frame}")
|
||||
await self.queue_frame(EndFrame())
|
||||
elif isinstance(frame, CancelTaskFrame):
|
||||
# Tell the task we should end right away.
|
||||
logger.debug(f"{self}: received cancel task frame {frame}")
|
||||
await self.queue_frame(CancelFrame())
|
||||
elif isinstance(frame, StopTaskFrame):
|
||||
# Tell the task we should stop nicely.
|
||||
logger.debug(f"{self}: received stop task frame {frame}")
|
||||
await self.queue_frame(StopFrame())
|
||||
elif isinstance(frame, InterruptionTaskFrame):
|
||||
# Tell the task we should interrupt the pipeline. Note that we are
|
||||
# bypassing the push queue and directly queue into the
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
if frame.fatal:
|
||||
logger.error(f"A fatal error occurred: {frame}")
|
||||
@@ -642,7 +700,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# Tell the task we should stop.
|
||||
await self.queue_frame(StopTaskFrame())
|
||||
else:
|
||||
logger.warning(f"Something went wrong: {frame}")
|
||||
logger.warning(f"{self}: Something went wrong: {frame}")
|
||||
|
||||
async def _sink_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames coming downstream from the pipeline.
|
||||
@@ -669,9 +727,11 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_start_event.set()
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._call_event_handler("on_pipeline_ended", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, StopFrame):
|
||||
await self._call_event_handler("on_pipeline_stopped", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import Optional
|
||||
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -24,7 +23,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -105,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Check for immediate flush conditions
|
||||
if frame.button == self._termination_digit:
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -36,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -48,7 +48,6 @@ from pipecat.frames.frames import (
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -138,7 +137,7 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._call_event_handler("on_completion", self._aggregation, False)
|
||||
self._aggregation = ""
|
||||
self._started = False
|
||||
@@ -532,9 +531,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
@@ -838,7 +837,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
@@ -904,7 +903,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self.push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
@@ -13,7 +13,6 @@ LLM processing, and text-to-speech components in conversational AI pipelines.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
@@ -23,7 +22,6 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -37,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -48,7 +47,6 @@ from pipecat.frames.frames import (
|
||||
LLMSetToolsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -311,9 +309,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
@@ -579,7 +577,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
@@ -645,7 +643,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
return self._num_channels
|
||||
|
||||
def has_audio(self) -> bool:
|
||||
"""Check if both user and bot audio buffers contain data.
|
||||
"""Check if either user or bot audio buffers contain data.
|
||||
|
||||
Returns:
|
||||
True if both buffers contain audio data.
|
||||
True if either buffer contains audio data.
|
||||
"""
|
||||
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
|
||||
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
|
||||
self._bot_audio_buffer
|
||||
)
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -204,7 +204,7 @@ class STTMuteFilter(FrameProcessor):
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
|
||||
@@ -28,8 +28,9 @@ from pipecat.frames.frames import (
|
||||
FrameProcessorPauseUrgentFrame,
|
||||
FrameProcessorResumeFrame,
|
||||
FrameProcessorResumeUrgentFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
|
||||
@@ -219,6 +220,9 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
"""Get the unique identifier for this processor.
|
||||
@@ -542,6 +546,14 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption we will bypass all queued system
|
||||
# frames and we will process the frame right away. This is because a
|
||||
# previous system frame might be waiting for the interruption frame and
|
||||
# it's blocking the input task.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
@@ -588,7 +600,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.__start(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._start_interruption()
|
||||
await self.stop_all_metrics()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
@@ -620,6 +632,32 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the pipeline
|
||||
task and waits to receive the corresponding `InterruptionFrame`. When
|
||||
the function finishes it is guaranteed that the `InterruptionFrame` has
|
||||
been pushed downstream.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for an `InterruptionFrame` to come to this processor and be
|
||||
# pushed. Take a look at `push_frame()` to see how we first push the
|
||||
# `InterruptionFrame` and then we set the event in order to maintain
|
||||
# frame ordering.
|
||||
await self._wait_interruption_event.wait()
|
||||
|
||||
# Clean the event.
|
||||
self._wait_interruption_event.clear()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
"""Handle the start frame to initialize processor state.
|
||||
|
||||
@@ -669,20 +707,22 @@ class FrameProcessor(BaseObject):
|
||||
async def _start_interruption(self):
|
||||
"""Start handling an interruption by cancelling current tasks."""
|
||||
try:
|
||||
# Cancel the process task. This will stop processing queued frames.
|
||||
await self.__cancel_process_task()
|
||||
if self._wait_for_interruption:
|
||||
# If we get here we know the process task was just waiting for
|
||||
# an interruption (push_interruption_task_frame_and_wait()), so
|
||||
# we can't cancel the task because it might still need to do
|
||||
# more things (e.g. pushing a frame after the
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
else:
|
||||
# Cancel and re-create the process task including the queue.
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
# Create a new process queue and task.
|
||||
self.__create_process_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
"""Stop handling an interruption."""
|
||||
# Nothing to do right now.
|
||||
pass
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
|
||||
@@ -764,6 +804,17 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
|
||||
|
||||
def __reset_process_task(self):
|
||||
"""Reset non-system frame processing task."""
|
||||
if self._enable_direct_mode:
|
||||
return
|
||||
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = asyncio.Event()
|
||||
while not self.__process_queue.empty():
|
||||
self.__process_queue.get_nowait()
|
||||
self.__process_queue.task_done()
|
||||
|
||||
async def __cancel_process_task(self):
|
||||
"""Cancel the non-system frame processing task."""
|
||||
if self.__process_frame_task:
|
||||
|
||||
@@ -30,7 +30,6 @@ from loguru import logger
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -1206,7 +1205,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
|
||||
@@ -19,7 +19,7 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
@@ -86,7 +86,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
transcript messages. Utterances are completed when:
|
||||
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (StartInterruptionFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame)
|
||||
"""
|
||||
|
||||
@@ -185,7 +185,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- StartInterruptionFrame: Completes current utterance due to interruption
|
||||
- InterruptionFrame: Completes current utterance due to interruption
|
||||
- EndFrame: Completes current utterance at pipeline end
|
||||
- CancelFrame: Completes current utterance due to cancellation
|
||||
|
||||
@@ -195,7 +195,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (StartInterruptionFrame, CancelFrame)):
|
||||
if isinstance(frame, (InterruptionFrame, CancelFrame)):
|
||||
# Push frame first otherwise our emitted transcription update frame
|
||||
# might get cleaned up.
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -51,9 +51,11 @@ class WebSocketRunnerArguments(RunnerArguments):
|
||||
|
||||
Parameters:
|
||||
websocket: WebSocket connection for audio streaming
|
||||
body: Additional request data
|
||||
"""
|
||||
|
||||
websocket: WebSocket
|
||||
body: Optional[Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -99,16 +99,35 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
tuple: (transport_type: str, call_data: dict)
|
||||
|
||||
call_data contains provider-specific fields:
|
||||
- Twilio: {"stream_id": str, "call_id": str}
|
||||
- Telnyx: {"stream_id": str, "call_control_id": str, "outbound_encoding": str}
|
||||
- Plivo: {"stream_id": str, "call_id": str}
|
||||
- Exotel: {"stream_id": str, "call_id": str, "account_sid": str}
|
||||
- Twilio: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
"body": dict
|
||||
}
|
||||
- Telnyx: {
|
||||
"stream_id": str,
|
||||
"call_control_id": str,
|
||||
"outbound_encoding": str,
|
||||
"from": str,
|
||||
"to": str,
|
||||
}
|
||||
- Plivo: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
}
|
||||
- Exotel: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
"account_sid": str,
|
||||
"from": str,
|
||||
"to": str,
|
||||
}
|
||||
|
||||
Example usage::
|
||||
|
||||
transport_type, call_data = await parse_telephony_websocket(websocket)
|
||||
if transport_type == "telnyx":
|
||||
outbound_encoding = call_data["outbound_encoding"]
|
||||
if transport_type == "twilio":
|
||||
user_id = call_data["body"]["user_id"]
|
||||
"""
|
||||
# Read first two messages
|
||||
start_data = websocket.iter_text()
|
||||
@@ -151,9 +170,12 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
# Extract provider-specific data
|
||||
if transport_type == "twilio":
|
||||
start_data = call_data_raw.get("start", {})
|
||||
body_data = start_data.get("customParameters", {})
|
||||
call_data = {
|
||||
"stream_id": start_data.get("streamSid"),
|
||||
"call_id": start_data.get("callSid"),
|
||||
# All custom parameters
|
||||
"body": body_data,
|
||||
}
|
||||
|
||||
elif transport_type == "telnyx":
|
||||
@@ -163,6 +185,8 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"outbound_encoding": call_data_raw.get("start", {})
|
||||
.get("media_format", {})
|
||||
.get("encoding"),
|
||||
"from": call_data_raw.get("start", {}).get("from", ""),
|
||||
"to": call_data_raw.get("start", {}).get("to", ""),
|
||||
}
|
||||
|
||||
elif transport_type == "plivo":
|
||||
@@ -178,6 +202,8 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"stream_id": start_data.get("stream_sid"),
|
||||
"call_id": start_data.get("call_sid"),
|
||||
"account_sid": start_data.get("account_sid"),
|
||||
"from": start_data.get("from", ""),
|
||||
"to": start_data.get("to", ""),
|
||||
}
|
||||
|
||||
else:
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -98,7 +98,7 @@ class ExotelFrameSerializer(FrameSerializer):
|
||||
Returns:
|
||||
Serialized data as string or bytes, or None if the frame isn't handled.
|
||||
"""
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -122,7 +122,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clearAudio", "streamId": self._stream_id}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -29,8 +29,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
|
||||
@@ -137,7 +137,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear"}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -122,7 +122,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -275,7 +275,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def _receive_messages(self):
|
||||
|
||||
@@ -25,7 +25,10 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import (
|
||||
AWSBedrockLLMAdapter,
|
||||
AWSBedrockLLMInvocationParams,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
@@ -808,64 +811,55 @@ class AWSBedrockLLMService(LLMService):
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
# Future code will be something like this:
|
||||
# adapter = self.get_llm_adapter()
|
||||
# params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
# messages = params["messages"]
|
||||
# system = params["system_instruction"] # [{"text": "system message"}]
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for AWS Bedrock."
|
||||
)
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _create_converse_stream(self, client, request_params):
|
||||
@@ -940,8 +934,25 @@ class AWSBedrockLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AWSBedrockLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(context)
|
||||
return params
|
||||
|
||||
# AWS Bedrock-specific context
|
||||
return AWSBedrockLLMInvocationParams(
|
||||
system=getattr(context, "system", None),
|
||||
messages=context.messages,
|
||||
tools=context.tools or [],
|
||||
tool_choice=context.tool_choice,
|
||||
)
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: AWSBedrockLLMContext):
|
||||
async def _process_context(self, context: AWSBedrockLLMContext | LLMContext):
|
||||
# Usage tracking
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
@@ -958,6 +969,12 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params_from_context = self._get_llm_invocation_params(context)
|
||||
messages = params_from_context["messages"]
|
||||
system = params_from_context["system"]
|
||||
tools = params_from_context["tools"]
|
||||
tool_choice = params_from_context["tool_choice"]
|
||||
|
||||
# Set up inference config
|
||||
inference_config = {
|
||||
"maxTokens": self._settings["max_tokens"],
|
||||
@@ -968,19 +985,18 @@ class AWSBedrockLLMService(LLMService):
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": self.model_name,
|
||||
"messages": context.messages,
|
||||
"messages": messages,
|
||||
"inferenceConfig": inference_config,
|
||||
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
|
||||
}
|
||||
|
||||
# Add system message
|
||||
system = getattr(context, "system", None)
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
# Check if messages contain tool use or tool result content blocks
|
||||
has_tool_content = False
|
||||
for message in context.messages:
|
||||
for message in messages:
|
||||
if isinstance(message.get("content"), list):
|
||||
for content_item in message["content"]:
|
||||
if "toolUse" in content_item or "toolResult" in content_item:
|
||||
@@ -990,7 +1006,6 @@ class AWSBedrockLLMService(LLMService):
|
||||
break
|
||||
|
||||
# Handle tools: use current tools, or no-op if tool content exists but no current tools
|
||||
tools = context.tools or []
|
||||
if has_tool_content and not tools:
|
||||
tools = [self._create_no_op_tool()]
|
||||
using_noop_tool = True
|
||||
@@ -999,17 +1014,15 @@ class AWSBedrockLLMService(LLMService):
|
||||
tool_config = {"tools": tools}
|
||||
|
||||
# Only add tool_choice if we have real tools (not just no-op)
|
||||
if not using_noop_tool and context.tool_choice:
|
||||
if context.tool_choice == "auto":
|
||||
if not using_noop_tool and tool_choice:
|
||||
if tool_choice == "auto":
|
||||
tool_config["toolChoice"] = {"auto": {}}
|
||||
elif context.tool_choice == "none":
|
||||
elif tool_choice == "none":
|
||||
# Skip adding toolChoice for "none"
|
||||
pass
|
||||
elif (
|
||||
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
|
||||
):
|
||||
elif isinstance(tool_choice, dict) and "function" in tool_choice:
|
||||
tool_config["toolChoice"] = {
|
||||
"tool": {"name": context.tool_choice["function"]["name"]}
|
||||
"tool": {"name": tool_choice["function"]["name"]}
|
||||
}
|
||||
|
||||
request_params["toolConfig"] = tool_config
|
||||
@@ -1019,9 +1032,16 @@ class AWSBedrockLLMService(LLMService):
|
||||
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
|
||||
|
||||
# Log request params with messages redacted for logging
|
||||
log_params = dict(request_params)
|
||||
log_params["messages"] = context.get_messages_for_logging()
|
||||
logger.debug(f"Calling AWS Bedrock model with: {log_params}")
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{system}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
@@ -1129,7 +1149,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AWSBedrockLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
|
||||
@@ -247,13 +247,14 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._assistant_response_trigger_audio: Optional[bytes] = (
|
||||
None # Not cleared on _disconnect()
|
||||
)
|
||||
self._disconnecting = False
|
||||
self._connected_time: Optional[float] = None
|
||||
self._wants_connection = False
|
||||
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
@@ -1099,20 +1100,13 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
self._triggering_assistant_response = True
|
||||
|
||||
# Read audio bytes, if we don't already have them cached
|
||||
if not self._assistant_response_trigger_audio:
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
# Send the trigger audio, if we're fully connected and set up
|
||||
if self._connected_time is not None:
|
||||
if self._connected_time:
|
||||
await self._send_assistant_response_trigger()
|
||||
|
||||
async def _send_assistant_response_trigger(self):
|
||||
if (
|
||||
not self._assistant_response_trigger_audio or self._connected_time is None
|
||||
): # should never happen
|
||||
if not self._connected_time:
|
||||
# should never happen
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -21,13 +21,13 @@ from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
@@ -306,7 +306,7 @@ class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
TextFrame,
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
@@ -140,6 +141,7 @@ class AzureSTTService(STTService):
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
@@ -197,3 +199,15 @@ class AzureSTTService(STTService):
|
||||
self._handle_transcription(event.result.text, True, language), self.get_event_loop()
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
def _on_handle_recognizing(self, event):
|
||||
if event.result.reason == ResultReason.RecognizingSpeech and len(event.result.text) > 0:
|
||||
language = getattr(event.result, "language", None) or self._settings.get("language")
|
||||
frame = InterimTranscriptionFrame(
|
||||
event.result.text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=event,
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -371,7 +371,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
if self._context_id:
|
||||
|
||||
@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -460,7 +460,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
@@ -549,7 +549,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
@@ -558,7 +558,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
logger.trace(f"Closing context {self._context_id} due to interruption")
|
||||
try:
|
||||
# ElevenLabs requires that Pipecat manages the contexts and closes them
|
||||
# when they're not longer in use. Since a StartInterruptionFrame is pushed
|
||||
# when they're not longer in use. Since an InterruptionFrame is pushed
|
||||
# every time the user speaks, we'll use this as a trigger to close the context
|
||||
# and reset the state.
|
||||
# Note: We do not need to call remove_audio_context here, as the context is
|
||||
@@ -856,7 +856,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (StartInterruptionFrame, TTSStoppedFrame)):
|
||||
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
|
||||
# Reset timing on interruption or stop
|
||||
self._reset_state()
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -259,7 +259,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
@@ -33,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
InputAudioRawFrame,
|
||||
InputImageRawFrame,
|
||||
InputTextRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -41,7 +42,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -752,7 +752,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
elif isinstance(frame, InputImageRawFrame):
|
||||
await self._send_user_video(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
|
||||
@@ -500,9 +500,11 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
Parameters:
|
||||
language: Language for synthesis. Defaults to English.
|
||||
speaking_rate: The speaking rate, in the range [0.25, 4.0].
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
speaking_rate: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -510,6 +512,7 @@ class GoogleTTSService(TTSService):
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
voice_cloning_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -520,6 +523,7 @@ class GoogleTTSService(TTSService):
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
|
||||
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
params: Language configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
@@ -532,8 +536,10 @@ class GoogleTTSService(TTSService):
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"speaking_rate": params.speaking_rate,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._voice_cloning_key = voice_cloning_key
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
@@ -600,15 +606,24 @@ class GoogleTTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
if self._voice_cloning_key:
|
||||
voice_clone_params = texttospeech_v1.VoiceCloneParams(
|
||||
voice_cloning_key=self._voice_cloning_key
|
||||
)
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], voice_clone=voice_clone_params
|
||||
)
|
||||
else:
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
|
||||
streaming_config = texttospeech_v1.StreamingSynthesizeConfig(
|
||||
voice=voice,
|
||||
streaming_audio_config=texttospeech_v1.StreamingAudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.PCM,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
speaking_rate=self._settings["speaking_rate"],
|
||||
),
|
||||
)
|
||||
config_request = texttospeech_v1.StreamingSynthesizeRequest(
|
||||
|
||||
@@ -240,6 +240,7 @@ class HeyGenVideoService(AIService):
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -36,15 +36,15 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsStartedFrame,
|
||||
InterruptionFrame,
|
||||
LLMConfigureOutputFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -195,6 +195,17 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return self.get_llm_adapter().create_llm_specific_message(message)
|
||||
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
@@ -269,7 +280,7 @@ class LLMService(AIService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._skip_tts = frame.skip_tts
|
||||
@@ -286,7 +297,7 @@ class LLMService(AIService):
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _handle_interruptions(self, _: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, _: InterruptionFrame):
|
||||
for function_name, entry in self._functions.items():
|
||||
if entry.cancel_on_interruption:
|
||||
await self._cancel_function_call(function_name)
|
||||
|
||||
@@ -16,8 +16,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -180,7 +180,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def _connect(self):
|
||||
|
||||
@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -224,7 +224,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -64,6 +64,7 @@ class OpenAITTSService(TTSService):
|
||||
model: str = "gpt-4o-mini-tts",
|
||||
sample_rate: Optional[int] = None,
|
||||
instructions: Optional[str] = None,
|
||||
speed: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI TTS service.
|
||||
@@ -75,6 +76,7 @@ class OpenAITTSService(TTSService):
|
||||
model: TTS model to use. Defaults to "gpt-4o-mini-tts".
|
||||
sample_rate: Output audio sample rate in Hz. If None, uses OpenAI's default 24kHz.
|
||||
instructions: Optional instructions to guide voice synthesis behavior.
|
||||
speed: Voice speed control (0.25 to 4.0, default 1.0).
|
||||
**kwargs: Additional keyword arguments passed to TTSService.
|
||||
"""
|
||||
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
@@ -84,6 +86,7 @@ class OpenAITTSService(TTSService):
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._speed = speed
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice)
|
||||
self._instructions = instructions
|
||||
@@ -133,17 +136,22 @@ class OpenAITTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Setup extra body parameters
|
||||
extra_body = {}
|
||||
# Setup API parameters
|
||||
create_params = {
|
||||
"input": text,
|
||||
"model": self.model_name,
|
||||
"voice": VALID_VOICES[self._voice_id],
|
||||
"response_format": "pcm",
|
||||
}
|
||||
|
||||
if self._instructions:
|
||||
extra_body["instructions"] = self._instructions
|
||||
create_params["instructions"] = self._instructions
|
||||
|
||||
if self._speed:
|
||||
create_params["speed"] = self._speed
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text,
|
||||
model=self.model_name,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
extra_body=extra_body,
|
||||
**create_params
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
error = await r.text()
|
||||
|
||||
209
src/pipecat/services/openai_agent/README.md
Normal file
209
src/pipecat/services/openai_agent/README.md
Normal file
@@ -0,0 +1,209 @@
|
||||
# OpenAI Agents SDK Integration
|
||||
|
||||
This service integrates the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) with Pipecat, enabling powerful agentic workflows with features like:
|
||||
|
||||
- **Agent loops** with tool calling and response streaming
|
||||
- **Handoffs** between specialized agents
|
||||
- **Guardrails** for input/output validation
|
||||
- **Sessions** with automatic conversation history
|
||||
- **Built-in tracing** and monitoring
|
||||
|
||||
## Installation
|
||||
|
||||
Install the OpenAI Agents SDK dependency:
|
||||
|
||||
```bash
|
||||
pip install "pipecat-ai[openai-agent]"
|
||||
# or
|
||||
uv add "pipecat-ai[openai-agent]"
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from pipecat.services.openai_agent import OpenAIAgentService
|
||||
|
||||
# Create a simple agent
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Use in a pipeline
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
agent_service,
|
||||
tts,
|
||||
transport.output(),
|
||||
])
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Tool Integration
|
||||
|
||||
```python
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"Weather in {location}: sunny, 22°C"
|
||||
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Weather Assistant",
|
||||
instructions="Help users with weather information.",
|
||||
tools=[get_weather],
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
```
|
||||
|
||||
### Agent Handoffs
|
||||
|
||||
```python
|
||||
# Create specialized agents
|
||||
weather_agent = OpenAIAgentService(
|
||||
name="Weather Specialist",
|
||||
instructions="Provide weather information and forecasts.",
|
||||
tools=[get_weather, get_forecast],
|
||||
)
|
||||
|
||||
trivia_agent = OpenAIAgentService(
|
||||
name="Trivia Master",
|
||||
instructions="Share interesting facts and trivia.",
|
||||
tools=[get_random_fact],
|
||||
)
|
||||
|
||||
# Create coordinator that can hand off to specialists
|
||||
coordinator = OpenAIAgentService(
|
||||
name="Coordinator",
|
||||
instructions="Route users to the right specialist.",
|
||||
handoffs=[weather_agent.agent, trivia_agent.agent],
|
||||
)
|
||||
```
|
||||
|
||||
### Guardrails
|
||||
|
||||
```python
|
||||
from agents import InputGuardrail, GuardrailFunctionOutput
|
||||
|
||||
async def content_filter(ctx, agent, input_data):
|
||||
# Check input for appropriate content
|
||||
if is_inappropriate(input_data):
|
||||
return GuardrailFunctionOutput(
|
||||
tripwire_triggered=True,
|
||||
output_info="Content not allowed"
|
||||
)
|
||||
return GuardrailFunctionOutput(tripwire_triggered=False)
|
||||
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Safe Assistant",
|
||||
instructions="You are a helpful and safe assistant.",
|
||||
input_guardrails=[InputGuardrail(guardrail_function=content_filter)],
|
||||
)
|
||||
```
|
||||
|
||||
### Session Management
|
||||
|
||||
```python
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Personal Assistant",
|
||||
instructions="Remember user preferences and context.",
|
||||
session_config={
|
||||
"user_id": "user_123",
|
||||
"memory_enabled": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Update session context dynamically
|
||||
agent_service.update_session_context({
|
||||
"user_preferences": {"language": "en", "style": "formal"}
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Basic Parameters
|
||||
|
||||
- `name`: Agent identifier for handoffs and tracing
|
||||
- `instructions`: System prompt defining agent behavior
|
||||
- `api_key`: OpenAI API key (or use `OPENAI_API_KEY` env var)
|
||||
- `streaming`: Enable real-time token streaming (default: True)
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
- `tools`: List of callable functions for the agent to use
|
||||
- `handoffs`: List of other agents this agent can transfer to
|
||||
- `input_guardrails`: Input validation and filtering
|
||||
- `output_guardrails`: Output validation and filtering
|
||||
- `model_config`: Model settings (model, temperature, etc.)
|
||||
- `session_config`: Session and memory configuration
|
||||
|
||||
### Model Configuration
|
||||
|
||||
```python
|
||||
agent_service = OpenAIAgentService(
|
||||
name="Precise Assistant",
|
||||
instructions="Provide accurate, concise responses.",
|
||||
model_config={
|
||||
"model": "gpt-4o",
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 150,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the foundational examples:
|
||||
|
||||
- [`45-openai-agent-basic.py`](../examples/foundational/45-openai-agent-basic.py) - Basic agent with tools
|
||||
- [`46-openai-agent-handoffs.py`](../examples/foundational/46-openai-agent-handoffs.py) - Multi-agent system with handoffs
|
||||
|
||||
## Methods
|
||||
|
||||
### Core Methods
|
||||
|
||||
- `update_agent_config()` - Update instructions and model settings
|
||||
- `add_tool()` - Add new tools dynamically
|
||||
- `add_handoff_agent()` - Add handoff destinations
|
||||
- `get_session_context()` - Get current session state
|
||||
- `update_session_context()` - Update session variables
|
||||
|
||||
### Lifecycle Methods
|
||||
|
||||
Inherited from `AIService`:
|
||||
- `start()` - Initialize the agent
|
||||
- `stop()` - Clean up resources
|
||||
- `cancel()` - Cancel ongoing operations
|
||||
|
||||
## Integration with Pipecat
|
||||
|
||||
The service processes `TextFrame` inputs and generates:
|
||||
- `LLMFullResponseStartFrame` - Response beginning
|
||||
- `LLMTextFrame` - Streaming text tokens (if streaming enabled)
|
||||
- `LLMFullResponseEndFrame` - Response completion
|
||||
|
||||
This integrates seamlessly with Pipecat's conversation pipeline and context aggregators.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The service includes robust error handling for:
|
||||
- Missing API keys or SDK installation
|
||||
- Agent processing failures
|
||||
- Network connectivity issues
|
||||
- Malformed tool responses
|
||||
|
||||
Errors are emitted as `ErrorFrame` objects in the pipeline.
|
||||
|
||||
## Requirements
|
||||
|
||||
- OpenAI API key
|
||||
- `openai-agents` package
|
||||
- Python 3.10+
|
||||
|
||||
## Limitations
|
||||
|
||||
- Currently supports OpenAI models only (via Agents SDK)
|
||||
- Handoffs work within individual requests (no cross-request state)
|
||||
- Real-time voice features require additional setup
|
||||
11
src/pipecat/services/openai_agent/__init__.py
Normal file
11
src/pipecat/services/openai_agent/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Agents SDK service for Pipecat integration."""
|
||||
|
||||
from .agent_service import OpenAIAgentService
|
||||
|
||||
__all__ = ["OpenAIAgentService"]
|
||||
567
src/pipecat/services/openai_agent/agent_service.py
Normal file
567
src/pipecat/services/openai_agent/agent_service.py
Normal file
@@ -0,0 +1,567 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Agents SDK integration service.
|
||||
|
||||
Provides integration with the OpenAI Agents SDK for building AI applications
|
||||
within Pipecat pipelines. This service allows leveraging agent loops, handoffs,
|
||||
guardrails, sessions, and tools from the OpenAI Agents SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Union,
|
||||
override,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from agents import Agent, InputGuardrail, OutputGuardrail, Runner, Tool
|
||||
from agents.result import RunResult, RunResultStreaming
|
||||
from agents.stream_events import StreamEvent
|
||||
except ImportError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI Agents SDK, you need to `pip install openai-agents`. "
|
||||
"Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolLike(Protocol):
|
||||
"""Protocol for tool-like objects."""
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Tool call interface."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentLike(Protocol):
|
||||
"""Protocol for agent-like objects."""
|
||||
|
||||
name: str
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Agent call interface."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIAgentContextAggregatorPair:
|
||||
"""Pair of OpenAI Agent context aggregators for user and assistant messages.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
_user: "OpenAIAgentUserContextAggregator"
|
||||
_assistant: "OpenAIAgentAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIAgentUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAgentAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAIAgentService(AIService):
|
||||
"""OpenAI Agents SDK service for Pipecat.
|
||||
|
||||
Integrates the OpenAI Agents SDK with Pipecat's pipeline architecture,
|
||||
enabling advanced agentic workflows with features like handoffs, guardrails,
|
||||
sessions, and tools within real-time conversational AI applications.
|
||||
|
||||
The service processes text input frames and generates streaming responses
|
||||
using the agent's configured capabilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: Optional[Agent] = None,
|
||||
name: str = "Assistant",
|
||||
instructions: Union[str, Sequence[str]] = "You are a helpful assistant.",
|
||||
handoffs: Optional[Sequence[AgentLike]] = None,
|
||||
tools: Optional[Sequence[ToolLike]] = None,
|
||||
input_guardrails: Optional[Sequence[InputGuardrail]] = None,
|
||||
output_guardrails: Optional[Sequence[OutputGuardrail]] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
session_config: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
streaming: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Agent service.
|
||||
|
||||
Args:
|
||||
agent: Pre-configured Agent instance. If provided, other agent configuration
|
||||
parameters will be ignored.
|
||||
name: Name of the agent for identification and handoffs.
|
||||
instructions: System instructions that define the agent's behavior.
|
||||
handoffs: List of other agents this agent can hand off to.
|
||||
tools: List of callable functions the agent can use as tools.
|
||||
input_guardrails: List of input validation guardrails.
|
||||
output_guardrails: List of output validation guardrails.
|
||||
model_config: Configuration for the underlying language model.
|
||||
session_config: Configuration for session management.
|
||||
api_key: OpenAI API key. If not provided, will use OPENAI_API_KEY env var.
|
||||
streaming: Whether to use streaming responses for real-time output.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set up API key
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
elif not os.getenv("OPENAI_API_KEY"):
|
||||
logger.warning("No OpenAI API key provided. Set OPENAI_API_KEY environment variable.")
|
||||
|
||||
# Create or use existing agent
|
||||
if agent:
|
||||
self._agent = agent
|
||||
else:
|
||||
# Convert sequences to lists and handle string instructions
|
||||
agent_handoffs: List[Any] = list(handoffs) if handoffs else []
|
||||
agent_tools: List[Any] = list(tools) if tools else []
|
||||
agent_input_guardrails: List[Any] = list(input_guardrails) if input_guardrails else []
|
||||
agent_output_guardrails: List[Any] = (
|
||||
list(output_guardrails) if output_guardrails else []
|
||||
)
|
||||
|
||||
# Handle instructions - convert sequence to string if needed
|
||||
if isinstance(instructions, str):
|
||||
agent_instructions = instructions
|
||||
else:
|
||||
agent_instructions = " ".join(str(instr) for instr in instructions)
|
||||
|
||||
self._agent = Agent(
|
||||
name=name,
|
||||
instructions=agent_instructions,
|
||||
handoffs=agent_handoffs,
|
||||
tools=agent_tools,
|
||||
input_guardrails=agent_input_guardrails,
|
||||
output_guardrails=agent_output_guardrails,
|
||||
model=model_config.get("model", "gpt-4o") if model_config else "gpt-4o",
|
||||
)
|
||||
|
||||
self._streaming = streaming
|
||||
self._session_config = session_config or {}
|
||||
self._current_session = None
|
||||
self._accumulated_text = ""
|
||||
|
||||
# Set model name for metrics
|
||||
if model_config and "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
else:
|
||||
self.set_model_name("gpt-4o") # Default model
|
||||
|
||||
logger.info(f"Initialized OpenAI Agent service: {self._agent.name}")
|
||||
|
||||
@property
|
||||
def agent(self) -> Agent:
|
||||
"""Get the underlying OpenAI Agent.
|
||||
|
||||
Returns:
|
||||
The configured Agent instance.
|
||||
"""
|
||||
return self._agent
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIAgentContextAggregatorPair:
|
||||
"""Create OpenAI-specific context aggregators for agent interactions.
|
||||
|
||||
Creates a pair of context aggregators optimized for OpenAI Agent interactions,
|
||||
including support for function calls, tool usage, and conversation management.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
OpenAIAgentContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIAgentContextAggregatorPair.
|
||||
"""
|
||||
user = OpenAIAgentUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAgentAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIAgentContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def update_agent_config(
|
||||
self,
|
||||
*,
|
||||
instructions: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Update agent configuration dynamically.
|
||||
|
||||
Args:
|
||||
instructions: New system instructions for the agent.
|
||||
model_config: Updated model configuration.
|
||||
**kwargs: Additional agent configuration parameters.
|
||||
"""
|
||||
if instructions:
|
||||
self._agent.instructions = instructions
|
||||
logger.info(f"Updated agent instructions for {self._agent.name}")
|
||||
|
||||
if model_config:
|
||||
# Note: OpenAI Agents SDK handles model configuration during agent creation
|
||||
# We can't update model_config after agent is created, but we can update our model name
|
||||
if "model" in model_config:
|
||||
self.set_model_name(model_config["model"])
|
||||
logger.info(f"Updated model config for {self._agent.name}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the OpenAI Agent service.
|
||||
|
||||
Initializes the agent session and prepares for processing.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
logger.info(f"Starting OpenAI Agent service: {self._agent.name}")
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the OpenAI Agent service.
|
||||
|
||||
Cleans up resources and ends the current session.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
logger.info(f"Stopping OpenAI Agent service: {self._agent.name}")
|
||||
await super().stop(frame)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the OpenAI Agent service.
|
||||
|
||||
Cancels any ongoing operations.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
logger.info(f"Cancelling OpenAI Agent service: {self._agent.name}")
|
||||
await super().cancel(frame)
|
||||
|
||||
@override
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
"""Process frames and handle agent interactions.
|
||||
|
||||
Processes OpenAILLMContextFrame and TextFrame by running them through the OpenAI Agent
|
||||
and streams the results back as LLM frames.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
# Process context frame through the agent
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
# Extract the latest user message from the context
|
||||
messages = frame.context.get_messages()
|
||||
if messages:
|
||||
# Get the last user message
|
||||
for message in reversed(messages):
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
user_input = " ".join(text_parts)
|
||||
else:
|
||||
user_input = str(content)
|
||||
|
||||
if user_input.strip():
|
||||
await self._process_agent_request(user_input)
|
||||
break
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent context: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
elif isinstance(frame, TextFrame):
|
||||
# Process text input through the agent directly (for backwards compatibility)
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self._process_agent_request(frame.text)
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing agent request: {e}")
|
||||
await self.push_error(ErrorFrame(f"Agent processing error: {e}"))
|
||||
else:
|
||||
# For frames we don't handle, pass them through with direction
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_agent_request(self, input_text: str):
|
||||
"""Process an agent request and stream the results.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
logger.debug(f"Processing agent request: {input_text}")
|
||||
|
||||
if self._streaming:
|
||||
await self._process_streaming_response(input_text)
|
||||
else:
|
||||
await self._process_non_streaming_response(input_text)
|
||||
|
||||
async def _process_streaming_response(self, input_text: str):
|
||||
"""Process a streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent with streaming
|
||||
result: RunResultStreaming = Runner.run_streamed(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
has_streaming_deltas = False
|
||||
|
||||
# Process the stream events
|
||||
async for event in result.stream_events():
|
||||
if event.type == "raw_response_event":
|
||||
# Handle token-by-token streaming
|
||||
# Only check for delta on events that are known to have it
|
||||
if hasattr(event.data, "delta") and getattr(event.data, "delta", None):
|
||||
delta_text = getattr(event.data, "delta", "")
|
||||
if delta_text:
|
||||
has_streaming_deltas = True
|
||||
self._accumulated_text += delta_text
|
||||
await self.push_frame(LLMTextFrame(text=delta_text))
|
||||
|
||||
elif event.type == "run_item_stream_event":
|
||||
# Handle completed items
|
||||
if event.item.type == "message_output_item":
|
||||
# Only process complete message if we didn't get streaming deltas
|
||||
if not has_streaming_deltas:
|
||||
message_text = self._extract_message_text(event.item)
|
||||
logger.debug(
|
||||
f"Processing complete message (no deltas): {message_text[:50]}..."
|
||||
if len(message_text) > 50
|
||||
else f"Processing complete message: {message_text}"
|
||||
)
|
||||
if message_text:
|
||||
await self.push_frame(LLMTextFrame(text=message_text))
|
||||
|
||||
elif event.item.type == "tool_call_item":
|
||||
# Use getattr for safe attribute access
|
||||
tool_name = getattr(event.item, "tool_name", "unknown")
|
||||
logger.debug(f"Tool called: {tool_name}")
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
output = getattr(event.item, "output", "no output")
|
||||
logger.debug(f"Tool output: {output}")
|
||||
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
logger.debug(f"Agent updated: {event.new_agent.name}")
|
||||
|
||||
# Reset accumulated text for next request
|
||||
self._accumulated_text = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming response: {e}")
|
||||
raise
|
||||
|
||||
async def _process_non_streaming_response(self, input_text: str):
|
||||
"""Process a non-streaming agent response.
|
||||
|
||||
Args:
|
||||
input_text: The user input text to process.
|
||||
"""
|
||||
try:
|
||||
# Run the agent without streaming
|
||||
result: RunResult = await Runner.run(
|
||||
self._agent, input_text, context=self._session_config
|
||||
)
|
||||
|
||||
# Send the final output
|
||||
if result.final_output:
|
||||
await self.push_frame(LLMTextFrame(text=result.final_output))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in non-streaming response: {e}")
|
||||
raise
|
||||
|
||||
def _extract_message_text(self, item) -> str:
|
||||
"""Extract text from a message output item.
|
||||
|
||||
Args:
|
||||
item: The message output item from the agent.
|
||||
|
||||
Returns:
|
||||
The extracted text content.
|
||||
"""
|
||||
try:
|
||||
# Handle OpenAI Agents SDK MessageOutputItem format
|
||||
if hasattr(item, "raw_item") and hasattr(item.raw_item, "content"):
|
||||
content = item.raw_item.content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for content_part in content:
|
||||
if hasattr(content_part, "text"):
|
||||
text_parts.append(content_part.text)
|
||||
elif (
|
||||
isinstance(content_part, dict)
|
||||
and content_part.get("type") == "output_text"
|
||||
):
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
# Handle direct content attribute
|
||||
elif hasattr(item, "content"):
|
||||
if isinstance(item.content, str):
|
||||
return item.content
|
||||
elif isinstance(item.content, list):
|
||||
# Extract text from content array
|
||||
text_parts = []
|
||||
for content_part in item.content:
|
||||
if isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
text_parts.append(content_part.get("text", ""))
|
||||
elif isinstance(content_part, str):
|
||||
text_parts.append(content_part)
|
||||
return "".join(text_parts)
|
||||
|
||||
# If no text content found, return empty string instead of str(item)
|
||||
logger.debug(f"No extractable text content found in item: {type(item)}")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from message item: {e}")
|
||||
return ""
|
||||
|
||||
async def add_tool(self, tool_function: ToolLike):
|
||||
"""Add a tool function to the agent.
|
||||
|
||||
Args:
|
||||
tool_function: A callable function or Tool object to add as a tool.
|
||||
"""
|
||||
if hasattr(self._agent, "tools"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
tools_list: List[Any] = self._agent.tools
|
||||
tools_list.append(tool_function)
|
||||
tool_name = getattr(
|
||||
tool_function, "__name__", getattr(tool_function, "name", "unknown")
|
||||
)
|
||||
logger.info(f"Added tool {tool_name} to agent {self._agent.name}")
|
||||
|
||||
async def add_handoff_agent(self, agent: AgentLike):
|
||||
"""Add a handoff agent.
|
||||
|
||||
Args:
|
||||
agent: Another Agent instance or handoff object that this agent can hand off to.
|
||||
"""
|
||||
if hasattr(self._agent, "handoffs"):
|
||||
# Cast to Any to handle the type variance issue
|
||||
handoffs_list: List[Any] = self._agent.handoffs
|
||||
handoffs_list.append(agent)
|
||||
agent_name = getattr(agent, "name", "unknown")
|
||||
logger.info(f"Added handoff agent {agent_name} to agent {self._agent.name}")
|
||||
|
||||
def get_session_context(self) -> Dict[str, Any]:
|
||||
"""Get the current session context.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current session context.
|
||||
"""
|
||||
return self._session_config.copy()
|
||||
|
||||
def update_session_context(self, context: Dict[str, Any]):
|
||||
"""Update the session context.
|
||||
|
||||
Args:
|
||||
context: Dictionary of context updates to apply.
|
||||
"""
|
||||
self._session_config.update(context)
|
||||
logger.debug(f"Updated session context for agent {self._agent.name}")
|
||||
|
||||
|
||||
class OpenAIAgentUserContextAggregator(LLMUserContextAggregator):
|
||||
"""OpenAI Agent-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for OpenAI Agent services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAgentAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""OpenAI Agent-specific assistant context aggregator.
|
||||
|
||||
Handles aggregation of assistant messages for OpenAI Agent services,
|
||||
with specialized support for OpenAI's function calling format,
|
||||
tool usage tracking, and agent interaction management.
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -31,7 +32,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -366,7 +366,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if not self._audio_input_paused:
|
||||
await self._send_user_audio(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
@@ -716,14 +716,12 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self._start_interruption() # cancels this processor task
|
||||
await self.push_frame(StartInterruptionFrame()) # cancels downstream tasks
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -32,7 +33,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -364,7 +364,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if not self._audio_input_paused:
|
||||
await self._send_user_audio(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
@@ -658,14 +658,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self._start_interruption() # cancels this processor task
|
||||
await self.push_frame(StartInterruptionFrame()) # cancels downstream tasks
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
|
||||
|
||||
@@ -25,8 +25,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -312,7 +312,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by stopping metrics and clearing request ID."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
|
||||
@@ -24,15 +24,14 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
|
||||
from pipecat.transcriptions import language
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
@@ -280,7 +279,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by clearing current context."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
@@ -375,7 +374,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -455,7 +455,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
OutputImageRawFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -179,7 +179,7 @@ class SimliVideoService(FrameProcessor):
|
||||
return
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
elif isinstance(frame, (StartInterruptionFrame, UserStartedSpeakingFrame)):
|
||||
elif isinstance(frame, (InterruptionFrame, UserStartedSpeakingFrame)):
|
||||
if not self._previously_interrupted:
|
||||
await self._simli_client.clearBuffer()
|
||||
self._previously_interrupted = self._is_trinity_avatar
|
||||
|
||||
@@ -19,7 +19,6 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
@@ -749,14 +748,13 @@ class SpeechmaticsSTTService(STTService):
|
||||
return
|
||||
|
||||
# Frames to send
|
||||
upstream_frames: list[Frame] = []
|
||||
downstream_frames: list[Frame] = []
|
||||
|
||||
# If VAD is enabled, then send a speaking frame
|
||||
if self._params.enable_vad and not self._is_speaking:
|
||||
logger.debug("User started speaking")
|
||||
self._is_speaking = True
|
||||
upstream_frames += [BotInterruptionFrame()]
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
downstream_frames += [UserStartedSpeakingFrame()]
|
||||
|
||||
# If final, then re-parse into TranscriptionFrame
|
||||
@@ -794,10 +792,6 @@ class SpeechmaticsSTTService(STTService):
|
||||
self._is_speaking = False
|
||||
downstream_frames += [UserStoppedSpeakingFrame()]
|
||||
|
||||
# Send UPSTREAM frames
|
||||
for frame in upstream_frames:
|
||||
await self.push_frame(frame, FrameDirection.UPSTREAM)
|
||||
|
||||
# Send the DOWNSTREAM frames
|
||||
for frame in downstream_frames:
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -23,12 +23,12 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputImageRawFrame,
|
||||
OutputTransportReadyFrame,
|
||||
SpeechOutputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
@@ -222,7 +222,7 @@ class TavusVideoService(AIService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
|
||||
@@ -20,10 +20,10 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -309,7 +309,7 @@ class TTSService(AIService):
|
||||
and not isinstance(frame, TranscriptionFrame)
|
||||
):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
@@ -367,14 +367,14 @@ class TTSService(AIService):
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
if self._push_stop_frames and (
|
||||
isinstance(frame, StartInterruptionFrame)
|
||||
isinstance(frame, InterruptionFrame)
|
||||
or isinstance(frame, TTSStartedFrame)
|
||||
or isinstance(frame, TTSAudioRawFrame)
|
||||
or isinstance(frame, TTSStoppedFrame)
|
||||
):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
self._processing_text = False
|
||||
await self._text_aggregator.handle_interruption()
|
||||
for filter in self._text_filters:
|
||||
@@ -438,7 +438,7 @@ class TTSService(AIService):
|
||||
)
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
has_started = True
|
||||
elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
elif isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
has_started = False
|
||||
except asyncio.TimeoutError:
|
||||
if has_started:
|
||||
@@ -523,7 +523,7 @@ class WordTTSService(TTSService):
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._llm_response_started = False
|
||||
self.reset_word_timestamps()
|
||||
@@ -613,7 +613,7 @@ class InterruptibleTTSService(WebsocketTTSService):
|
||||
# user interrupts we need to reconnect.
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
if self._bot_speaking:
|
||||
await self._disconnect()
|
||||
@@ -685,7 +685,7 @@ class InterruptibleWordTTSService(WebsocketWordTTSService):
|
||||
# user interrupts we need to reconnect.
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
if self._bot_speaking:
|
||||
await self._disconnect()
|
||||
@@ -813,7 +813,7 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._stop_audio_context_task()
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self._stop_audio_context_task()
|
||||
self._create_audio_context_task()
|
||||
|
||||
@@ -128,7 +128,7 @@ async def run_test(
|
||||
expected_up_frames: Optional[Sequence[type]] = None,
|
||||
ignore_start: bool = True,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
start_metadata: Optional[Dict[str, Any]] = None,
|
||||
pipeline_params: Optional[PipelineParams] = None,
|
||||
send_end_frame: bool = True,
|
||||
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
|
||||
"""Run a test pipeline with the specified processor and validate frame flow.
|
||||
@@ -144,7 +144,7 @@ async def run_test(
|
||||
expected_up_frames: Expected frame types flowing upstream (optional).
|
||||
ignore_start: Whether to ignore StartFrames in frame validation.
|
||||
observers: Optional list of observers to attach to the pipeline.
|
||||
start_metadata: Optional metadata to include with the StartFrame.
|
||||
pipeline_params: Optional pipeline parameters.
|
||||
send_end_frame: Whether to send an EndFrame at the end of the test.
|
||||
|
||||
Returns:
|
||||
@@ -154,7 +154,7 @@ async def run_test(
|
||||
AssertionError: If the received frames don't match the expected frame types.
|
||||
"""
|
||||
observers = observers or []
|
||||
start_metadata = start_metadata or {}
|
||||
pipeline_params = pipeline_params or PipelineParams()
|
||||
|
||||
received_up = asyncio.Queue()
|
||||
received_down = asyncio.Queue()
|
||||
@@ -173,7 +173,7 @@ async def run_test(
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(start_metadata=start_metadata),
|
||||
params=pipeline_params,
|
||||
observers=observers,
|
||||
cancel_on_idle_timeout=False,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.audio.turn.base_turn_analyzer import (
|
||||
)
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -36,7 +35,6 @@ from pipecat.frames.frames import (
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopFrame,
|
||||
SystemFrame,
|
||||
UserSpeakingFrame,
|
||||
@@ -289,8 +287,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_bot_interruption(frame)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._handle_bot_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -335,13 +331,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
|
||||
"""Handle bot interruption frames."""
|
||||
logger.debug("Bot interruption")
|
||||
if self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
|
||||
async def _handle_user_interruption(self, vad_state: VADState, emulated: bool = False):
|
||||
"""Handle user interruption events based on speaking state."""
|
||||
if vad_state == VADState.SPEAKING:
|
||||
@@ -353,7 +342,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
# Only push StartInterruptionFrame if:
|
||||
# Only push InterruptionFrame if:
|
||||
# 1. No interruption config is set, OR
|
||||
# 2. Interruption config is set but bot is not speaking
|
||||
should_push_immediate_interruption = (
|
||||
@@ -362,11 +351,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task) to stop everything, specially at the output
|
||||
# transport.
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
@@ -381,9 +366,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
if self.interruptions_allowed:
|
||||
await self._stop_interruption()
|
||||
|
||||
#
|
||||
# Handle bot speaking state
|
||||
#
|
||||
|
||||
@@ -30,6 +30,7 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputTransportMessageUrgentFrame,
|
||||
InterruptionFrame,
|
||||
MixerControlFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputDTMFFrame,
|
||||
@@ -39,7 +40,6 @@ from pipecat.frames.frames import (
|
||||
SpeechOutputAudioRawFrame,
|
||||
SpriteFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
@@ -287,9 +287,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
#
|
||||
# System frames (like StartInterruptionFrame) are pushed
|
||||
# immediately. Other frames require order so they are put in the sink
|
||||
# queue.
|
||||
# System frames (like InterruptionFrame) are pushed immediately. Other
|
||||
# frames require order so they are put in the sink queue.
|
||||
#
|
||||
if isinstance(frame, StartFrame):
|
||||
# Push StartFrame before start(), because we want StartFrame to be
|
||||
@@ -299,7 +298,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_frame(frame)
|
||||
elif isinstance(frame, TransportMessageUrgentFrame) and not isinstance(
|
||||
@@ -340,7 +339,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
sender = self._media_senders[frame.transport_destination]
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await sender.handle_interruptions(frame)
|
||||
elif isinstance(frame, OutputAudioRawFrame):
|
||||
await sender.handle_audio_frame(frame)
|
||||
@@ -491,7 +490,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._cancel_clock_task()
|
||||
await self._cancel_video_task()
|
||||
|
||||
async def handle_interruptions(self, _: StartInterruptionFrame):
|
||||
async def handle_interruptions(self, _: InterruptionFrame):
|
||||
"""Handle interruption events by restarting tasks and clearing buffers.
|
||||
|
||||
Args:
|
||||
@@ -672,7 +671,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
frame = self._audio_queue.get_nowait()
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
frame.audio = await self._mixer.mix(frame.audio)
|
||||
last_frame_time = time.time()
|
||||
last_frame_time = time.time()
|
||||
yield frame
|
||||
except asyncio.QueueEmpty:
|
||||
# Notify the bot stopped speaking upstream if necessary.
|
||||
|
||||
@@ -478,7 +478,11 @@ class SmallWebRTCClient:
|
||||
self._screen_video_track = None
|
||||
self._audio_output_track = None
|
||||
self._video_output_track = None
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._closing:
|
||||
await self._callbacks.on_client_disconnected(self._webrtc_connection)
|
||||
|
||||
async def _handle_app_message(self, message: Any):
|
||||
"""Handle incoming application messages."""
|
||||
|
||||
@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -618,7 +618,7 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions()
|
||||
|
||||
async def _handle_interruptions(self):
|
||||
|
||||
@@ -26,9 +26,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -138,7 +138,6 @@ class FastAPIWebsocketClient:
|
||||
):
|
||||
logger.warning("Closing already disconnected websocket!")
|
||||
self._closing = True
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect the WebSocket client."""
|
||||
@@ -152,8 +151,6 @@ class FastAPIWebsocketClient:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception while closing the websocket: {e}")
|
||||
finally:
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def trigger_client_disconnected(self):
|
||||
"""Trigger the client disconnected callback."""
|
||||
@@ -298,7 +295,10 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._client.trigger_client_disconnected()
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._client.is_closing:
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
@@ -398,7 +398,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._write_frame(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
@@ -446,6 +446,9 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
"""Serialize and send a frame through the WebSocket."""
|
||||
if self._client.is_closing or not self._client.is_connected:
|
||||
return
|
||||
|
||||
if not self._params.serializer:
|
||||
return
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -334,7 +334,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._write_frame(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
|
||||
172
test_openai_agent.py
Normal file
172
test_openai_agent.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Simple test script for OpenAI Agent service."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Mock the OpenAI API key for testing
|
||||
os.environ["OPENAI_API_KEY"] = "test-key-for-testing"
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai_agent import OpenAIAgentService
|
||||
|
||||
|
||||
async def test_basic_functionality():
|
||||
"""Test basic OpenAI Agent service functionality."""
|
||||
print("🧪 Testing OpenAI Agent Service...")
|
||||
|
||||
# Create a simple weather tool for testing
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"The weather in {location} is sunny and 22°C."
|
||||
|
||||
try:
|
||||
# Create the service
|
||||
print("📋 Creating OpenAI Agent service...")
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
tools=[get_weather],
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
print(f"✅ Service created successfully!")
|
||||
print(f" - Agent name: {service.agent.name}")
|
||||
print(f" - Model name: {service.model_name}")
|
||||
print(f" - Streaming enabled: {service._streaming}")
|
||||
|
||||
# Test basic configuration
|
||||
print("⚙️ Testing configuration updates...")
|
||||
service.update_agent_config(
|
||||
instructions="Updated test instructions",
|
||||
model_config={"model": "gpt-4o", "temperature": 0.5},
|
||||
)
|
||||
|
||||
print(f"✅ Configuration updated!")
|
||||
print(f" - New instructions: {service.agent.instructions}")
|
||||
print(f" - New model: {service.model_name}")
|
||||
|
||||
# Test session context
|
||||
print("💾 Testing session context...")
|
||||
service.update_session_context({"user_id": "test-user", "session": "test-session"})
|
||||
context = service.get_session_context()
|
||||
|
||||
print(f"✅ Session context managed!")
|
||||
print(f" - Context keys: {list(context.keys())}")
|
||||
|
||||
# Test adding tools
|
||||
print("🔧 Testing tool management...")
|
||||
|
||||
def get_time() -> str:
|
||||
"""Get current time."""
|
||||
return "The current time is 3:00 PM."
|
||||
|
||||
await service.add_tool(get_time)
|
||||
print(f"✅ Tool added successfully!")
|
||||
|
||||
print("\n🎉 All basic functionality tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_frame_processing():
|
||||
"""Test frame processing with mocked responses."""
|
||||
print("\n🔄 Testing frame processing...")
|
||||
|
||||
try:
|
||||
# Mock the Runner to avoid actual API calls
|
||||
with patch("pipecat.services.openai_agent.agent_service.Runner") as mock_runner:
|
||||
# Set up mock responses
|
||||
mock_stream_result = MagicMock()
|
||||
|
||||
# Mock stream events
|
||||
async def mock_stream_events():
|
||||
# Simulate streaming response
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="Hello "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="from "))
|
||||
yield MagicMock(type="raw_response_event", data=MagicMock(delta="agent!"))
|
||||
|
||||
# Simulate completed message
|
||||
mock_item = MagicMock()
|
||||
mock_item.type = "message_output_item"
|
||||
mock_item.content = "Hello from agent!"
|
||||
yield MagicMock(type="run_item_stream_event", item=mock_item)
|
||||
|
||||
mock_stream_result.stream_events.return_value = mock_stream_events()
|
||||
mock_runner.run_streamed.return_value = mock_stream_result
|
||||
|
||||
# Create service with mocked runner
|
||||
service = OpenAIAgentService(
|
||||
name="Test Assistant",
|
||||
instructions="You are a helpful test assistant.",
|
||||
api_key="test-key",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# Collect output frames
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
print(f" 📤 Frame: {type(frame).__name__}")
|
||||
if hasattr(frame, "text"):
|
||||
print(f" Text: '{frame.text}'")
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
print("📝 Processing text frame...")
|
||||
text_frame = TextFrame("Hello, how are you?")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait for async processing
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
print(f"✅ Frame processing completed!")
|
||||
print(f" - Generated {len(output_frames)} output frames")
|
||||
|
||||
# Check if we got expected frame types
|
||||
frame_types = [type(frame).__name__ for frame in output_frames]
|
||||
print(f" - Frame types: {frame_types}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Frame processing test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting OpenAI Agent Service Tests\n")
|
||||
|
||||
try:
|
||||
# Run basic functionality tests
|
||||
basic_test = await test_basic_functionality()
|
||||
|
||||
# Run frame processing tests
|
||||
frame_test = await test_frame_processing()
|
||||
|
||||
# Summary
|
||||
print(f"\n📊 Test Results:")
|
||||
print(f" - Basic functionality: {'✅ PASS' if basic_test else '❌ FAIL'}")
|
||||
print(f" - Frame processing: {'✅ PASS' if frame_test else '❌ FAIL'}")
|
||||
|
||||
if basic_test and frame_test:
|
||||
print(f"\n🎉 All tests passed! The OpenAI Agent service is working correctly.")
|
||||
else:
|
||||
print(f"\n⚠️ Some tests failed. Please check the output above.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test suite failed with error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
33
test_simple_agent.py
Normal file
33
test_simple_agent.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# Test the actual agents package API
|
||||
try:
|
||||
from agents import Agent, run
|
||||
|
||||
# Create a simple agent
|
||||
agent = Agent(
|
||||
name="test-agent",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
print("✅ Agent created successfully!")
|
||||
print(f"Agent name: {agent.name}")
|
||||
|
||||
# Test a simple conversation
|
||||
async def test_agent():
|
||||
result = await run(agent, "Hello, how are you?")
|
||||
print(f"Agent response: {result}")
|
||||
|
||||
# Run the test
|
||||
asyncio.run(test_agent())
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -8,25 +8,31 @@ import json
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -36,6 +42,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.anthropic.llm import (
|
||||
AnthropicAssistantContextAggregator,
|
||||
AnthropicLLMContext,
|
||||
@@ -481,6 +488,103 @@ class BaseTestUserContextAggregator:
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
async def test_min_words_interruption_strategy_one_word(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
class ContextProcessor(FrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.context_received = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
self.context_received = True
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
context_processor = ContextProcessor()
|
||||
pipeline = Pipeline([aggregator, context_processor])
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Can", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
pipeline_params=PipelineParams(
|
||||
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
|
||||
),
|
||||
)
|
||||
assert not context_processor.context_received
|
||||
|
||||
async def test_min_words_interruption_strategy_two_words(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
class ContextProcessor(FrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.context_received = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
self.context_received = True
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
self.context_received = False
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
context_processor = ContextProcessor()
|
||||
pipeline = Pipeline([aggregator, context_processor])
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Can you", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_up_frames = [InterruptionTaskFrame]
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_up_frames=expected_up_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
pipeline_params=PipelineParams(
|
||||
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
|
||||
),
|
||||
)
|
||||
self.check_message_content(context, 0, "Can you")
|
||||
# If the context is not received or it has been cleared by the
|
||||
# interruption then we have an issue.
|
||||
assert context_processor.context_received
|
||||
|
||||
|
||||
class BaseTestAssistantContextAggreagator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
@@ -618,7 +722,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
TextFrame(text="Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
SleepFrame(AGGREGATION_SLEEP),
|
||||
StartInterruptionFrame(),
|
||||
InterruptionFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
@@ -626,7 +730,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
]
|
||||
expected_down_frames = [
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
|
||||
@@ -10,6 +10,7 @@ from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
|
||||
@@ -28,6 +29,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
@@ -59,9 +61,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # First aggregation "12"
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame, # Second aggregation "3"
|
||||
]
|
||||
|
||||
@@ -93,10 +97,12 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # "12#"
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # "45"
|
||||
]
|
||||
@@ -125,6 +131,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # Should flush before EndFrame
|
||||
EndFrame,
|
||||
@@ -152,6 +159,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame,
|
||||
]
|
||||
@@ -178,6 +186,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -214,7 +223,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
|
||||
# All the InputDTMFFrames plus one TranscriptionFrame
|
||||
expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame]
|
||||
expected_down_frames = (
|
||||
[InputDTMFFrame, InterruptionFrame]
|
||||
+ [InputDTMFFrame] * (len(frames_to_send) - 1)
|
||||
+ [TranscriptionFrame]
|
||||
)
|
||||
|
||||
received_down_frames, _ = await run_test(
|
||||
aggregator,
|
||||
|
||||
998
tests/test_get_llm_invocation_params.py
Normal file
998
tests/test_get_llm_invocation_params.py
Normal file
@@ -0,0 +1,998 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for LLM adapters' get_llm_invocation_params() method.
|
||||
|
||||
These tests focus specifically on the "messages" field generation for different adapters, ensuring:
|
||||
|
||||
For OpenAI adapter:
|
||||
1. LLMStandardMessage objects are passed through unchanged
|
||||
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
|
||||
3. Complex message structures (like multi-part content) are preserved
|
||||
4. System instructions are preserved throughout messages at any position
|
||||
|
||||
For Gemini adapter:
|
||||
1. LLMStandardMessage objects are converted to Gemini Content format
|
||||
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
|
||||
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
|
||||
4. System messages are extracted as system_instruction (without duplication)
|
||||
5. Single system instruction is converted to user message when no other messages exist
|
||||
6. Multiple system instructions: first extracted, later ones converted to user messages
|
||||
|
||||
For Anthropic adapter:
|
||||
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
|
||||
For AWS Bedrock adapter:
|
||||
1. LLMStandardMessage objects are converted to AWS Bedrock format
|
||||
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from google.genai.types import Content, Part
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = OpenAILLMAdapter()
|
||||
|
||||
def test_standard_messages_passed_through_unchanged(self):
|
||||
"""Test that LLMStandardMessage objects are passed through unchanged to OpenAI params."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify messages are passed through unchanged
|
||||
self.assertEqual(params["messages"], standard_messages)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify content matches exactly
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
|
||||
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that OpenAI-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Gemini specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
{"role": "assistant", "content": "OpenAI specific response"}
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and OpenAI-specific ones
|
||||
# (3 total: system, standard user, openai assistant)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
self.assertEqual(params["messages"][1]["content"], "Standard user message")
|
||||
self.assertEqual(
|
||||
params["messages"][2], {"role": "assistant", "content": "OpenAI specific response"}
|
||||
)
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved."""
|
||||
# Create a message with complex content structure (text + image)
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..."},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant that can analyze images."},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex content is preserved
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][1], complex_image_message)
|
||||
self.assertEqual(params["messages"][2], multi_text_message)
|
||||
|
||||
# Verify the image message structure is maintained
|
||||
image_content = params["messages"][1]["content"]
|
||||
self.assertIsInstance(image_content, list)
|
||||
self.assertEqual(len(image_content), 2)
|
||||
self.assertEqual(image_content[0]["type"], "text")
|
||||
self.assertEqual(image_content[1]["type"], "image_url")
|
||||
|
||||
# Verify the multi-text message structure is maintained
|
||||
text_content = params["messages"][2]["content"]
|
||||
self.assertIsInstance(text_content, list)
|
||||
self.assertEqual(len(text_content), 3)
|
||||
for i, text_block in enumerate(text_content):
|
||||
self.assertEqual(text_block["type"], "text")
|
||||
self.assertEqual(text_content[0]["text"], "Let me analyze this step by step:")
|
||||
self.assertEqual(text_content[1]["text"], "1. First, I'll examine the visual elements")
|
||||
self.assertEqual(text_content[2]["text"], "2. Then I'll provide my conclusions")
|
||||
|
||||
def test_system_instructions_preserved_throughout_messages(self):
|
||||
"""Test that OpenAI adapter preserves system instructions sprinkled throughout messages."""
|
||||
# Create messages with system instructions at different positions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# OpenAI should preserve all messages unchanged, including multiple system messages
|
||||
self.assertEqual(len(params["messages"]), 7)
|
||||
|
||||
# Verify system messages are preserved at their original positions
|
||||
self.assertEqual(params["messages"][0]["role"], "system")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
self.assertEqual(params["messages"][3]["role"], "system")
|
||||
self.assertEqual(params["messages"][3]["content"], "Remember to be concise.")
|
||||
|
||||
self.assertEqual(params["messages"][5]["role"], "system")
|
||||
self.assertEqual(params["messages"][5]["content"], "Use simple language.")
|
||||
|
||||
# Verify other messages remain unchanged
|
||||
self.assertEqual(params["messages"][1]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][4]["role"], "user")
|
||||
self.assertEqual(params["messages"][6]["role"], "assistant")
|
||||
|
||||
|
||||
class TestGeminiGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = GeminiLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_gemini_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Gemini Content format."""
|
||||
# Create standard messages (OpenAI format)
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are converted to Gemini format (2 messages: user + model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg, Content)
|
||||
self.assertEqual(user_msg.role, "user")
|
||||
self.assertEqual(len(user_msg.parts), 1)
|
||||
self.assertEqual(user_msg.parts[0].text, "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant -> model)
|
||||
model_msg = params["messages"][1]
|
||||
self.assertIsInstance(model_msg, Content)
|
||||
self.assertEqual(model_msg.role, "model")
|
||||
self.assertEqual(len(model_msg.parts), 1)
|
||||
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Gemini-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific message"}
|
||||
),
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
self.adapter.create_llm_specific_message(
|
||||
Content(role="model", parts=[Part(text="Gemini specific response")]),
|
||||
),
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only include standard messages and Gemini-specific ones
|
||||
# (2 total: converted standard user + gemini model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify the correct messages are included
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "Standard user message")
|
||||
|
||||
self.assertEqual(params["messages"][1].role, "model")
|
||||
self.assertEqual(params["messages"][1].parts[0].text, "Gemini specific response")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message content (like multi-part messages) is preserved and converted.
|
||||
|
||||
This test covers image, audio, and multi-text content conversion to Gemini format.
|
||||
"""
|
||||
# Create a message with complex content structure (text + image)
|
||||
# Using a minimal valid base64 image data
|
||||
complex_image_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with multiple text blocks
|
||||
multi_text_message = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me analyze this step by step:"},
|
||||
{"type": "text", "text": "1. First, I'll examine the visual elements"},
|
||||
{"type": "text", "text": "2. Then I'll provide my conclusions"},
|
||||
],
|
||||
}
|
||||
|
||||
# Create a message with audio input (text + audio)
|
||||
# Using a minimal valid base64 audio data (16 bytes of WAV header)
|
||||
audio_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you transcribe this audio?"},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=",
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can analyze images and audio.",
|
||||
},
|
||||
complex_image_message,
|
||||
multi_text_message,
|
||||
audio_message,
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(
|
||||
params["system_instruction"],
|
||||
"You are a helpful assistant that can analyze images and audio.",
|
||||
)
|
||||
|
||||
# Verify complex content is converted to Gemini format
|
||||
# Note: Gemini adapter may add system instruction back as user message in some cases
|
||||
self.assertGreaterEqual(len(params["messages"]), 3)
|
||||
|
||||
# Find the different message types
|
||||
user_with_image = None
|
||||
model_with_text = None
|
||||
user_with_audio = None
|
||||
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and len(msg.parts) == 2:
|
||||
# Check if it's image or audio based on the text content
|
||||
if hasattr(msg.parts[0], "text") and "image" in msg.parts[0].text:
|
||||
user_with_image = msg
|
||||
elif hasattr(msg.parts[0], "text") and "audio" in msg.parts[0].text:
|
||||
user_with_audio = msg
|
||||
elif msg.role == "model" and len(msg.parts) == 3:
|
||||
model_with_text = msg
|
||||
|
||||
# Verify the image message structure is converted properly
|
||||
self.assertIsNotNone(user_with_image, "Should have user message with image")
|
||||
self.assertEqual(len(user_with_image.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_image.parts[0].text, "What's in this image?")
|
||||
|
||||
# Second part should be image data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_image.parts[1].inline_data)
|
||||
self.assertEqual(user_with_image.parts[1].inline_data.mime_type, "image/jpeg")
|
||||
|
||||
# Verify the audio message structure is converted properly
|
||||
self.assertIsNotNone(user_with_audio, "Should have user message with audio")
|
||||
self.assertEqual(len(user_with_audio.parts), 2)
|
||||
|
||||
# First part should be text
|
||||
self.assertEqual(user_with_audio.parts[0].text, "Can you transcribe this audio?")
|
||||
|
||||
# Second part should be audio data (converted to Blob)
|
||||
self.assertIsNotNone(user_with_audio.parts[1].inline_data)
|
||||
self.assertEqual(user_with_audio.parts[1].inline_data.mime_type, "audio/wav")
|
||||
|
||||
# Verify the multi-text message structure is converted properly
|
||||
self.assertIsNotNone(model_with_text, "Should have model message with multi-text")
|
||||
self.assertEqual(len(model_with_text.parts), 3)
|
||||
|
||||
# All parts should be text
|
||||
expected_texts = [
|
||||
"Let me analyze this step by step:",
|
||||
"1. First, I'll examine the visual elements",
|
||||
"2. Then I'll provide my conclusions",
|
||||
]
|
||||
for i, expected_text in enumerate(expected_texts):
|
||||
self.assertEqual(model_with_text.parts[i].text, expected_text)
|
||||
|
||||
def test_single_system_instruction_converted_to_user(self):
|
||||
"""Test that when there's only a system instruction, it gets converted to user message."""
|
||||
# Create context with only a system message
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# But since there are no other messages, it should also be added back as a user message
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0].role, "user")
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "You are a helpful assistant.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
# Create messages with multiple system instructions
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."},
|
||||
{"role": "user", "content": "Tell me about Python."},
|
||||
{"role": "system", "content": "Use simple language."},
|
||||
{"role": "assistant", "content": "Python is a programming language."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages=messages)
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# First system instruction should be extracted
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 6 messages (original 7 minus 1 system instruction that was extracted)
|
||||
self.assertEqual(len(params["messages"]), 6)
|
||||
|
||||
# Find the converted system messages (should be user role now)
|
||||
converted_system_messages = []
|
||||
for msg in params["messages"]:
|
||||
if msg.role == "user" and (
|
||||
msg.parts[0].text == "Remember to be concise."
|
||||
or msg.parts[0].text == "Use simple language."
|
||||
):
|
||||
converted_system_messages.append(msg.parts[0].text)
|
||||
|
||||
# Should have 2 converted system messages
|
||||
self.assertEqual(len(converted_system_messages), 2)
|
||||
self.assertIn("Remember to be concise.", converted_system_messages)
|
||||
self.assertIn("Use simple language.", converted_system_messages)
|
||||
|
||||
# Verify that regular user and assistant messages are preserved
|
||||
user_messages = [msg for msg in params["messages"] if msg.role == "user"]
|
||||
model_messages = [msg for msg in params["messages"] if msg.role == "model"]
|
||||
|
||||
# Should have 4 user messages: 2 original + 2 converted from system
|
||||
self.assertEqual(len(user_messages), 4)
|
||||
# Should have 2 model messages (converted from assistant)
|
||||
self.assertEqual(len(model_messages), 2)
|
||||
|
||||
|
||||
class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AnthropicLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_anthropic_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to Anthropic MessageParam format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify system instruction is extracted
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertEqual(user_msg["content"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant)
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Anthropic-specific messages are included and others are filtered out."""
|
||||
# Create anthropic-specific message content
|
||||
anthropic_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "fake_data"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(anthropic_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + anthropic-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + anthropic-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + anthropic text + anthropic image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "image")
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["content"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to Anthropic format."""
|
||||
# Create a complex message with both text and image content
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,fake_image_data"},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: Anthropic adapter reorders single images to come before text, as per Anthropic docs
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "image")
|
||||
self.assertIn("source", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["source"]["type"], "base64")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["media_type"], "image/jpeg")
|
||||
self.assertEqual(user_msg["content"][0]["source"]["data"], "fake_image_data")
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System instruction should be extracted from first message
|
||||
self.assertEqual(params["system"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_converted_to_user(self):
|
||||
"""Test that a single system message is converted to user role when no other messages exist."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# System should be NOT_GIVEN since we only have one message
|
||||
from anthropic import NOT_GIVEN
|
||||
|
||||
self.assertEqual(params["system"], NOT_GIVEN)
|
||||
|
||||
# Single system message should be converted to user role
|
||||
self.assertEqual(len(params["messages"]), 1)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
|
||||
class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""Sets up a common adapter instance for all tests."""
|
||||
self.adapter = AWSBedrockLLMAdapter()
|
||||
|
||||
def test_standard_messages_converted_to_aws_bedrock_format(self):
|
||||
"""Test that LLMStandardMessage objects are converted to AWS Bedrock format."""
|
||||
# Create standard messages
|
||||
standard_messages: list[LLMStandardMessage] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=standard_messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction is extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages are in the params (2 messages after system extraction)
|
||||
self.assertIn("messages", params)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check first message (user) - should be converted to AWS Bedrock format
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 1)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Hello, how are you?")
|
||||
|
||||
# Check second message (assistant) - should be converted to AWS Bedrock format
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 1)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
|
||||
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that AWS-specific messages are included and others are filtered out."""
|
||||
# Create aws-specific message content (which is what AWS Bedrock uses)
|
||||
aws_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello"},
|
||||
{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(message=aws_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + aws-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message: merged user message (standard + aws-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + aws text + aws image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertIn("image", user_msg["content"][2])
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
"""Test that consecutive messages with the same role are merged into multi-content blocks."""
|
||||
messages = [
|
||||
{"role": "user", "content": "First user message"},
|
||||
{"role": "user", "content": "Second user message"},
|
||||
{"role": "user", "content": "Third user message"},
|
||||
{"role": "assistant", "content": "First assistant message"},
|
||||
{"role": "assistant", "content": "Second assistant message"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should have 2 messages after merging (1 user, 1 assistant)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# Check merged user message
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "First user message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Second user message")
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Third user message")
|
||||
|
||||
# Check merged assistant message
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(len(assistant_msg["content"]), 2)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "First assistant message")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Second assistant message")
|
||||
|
||||
def test_empty_text_converted_to_empty_placeholder(self):
|
||||
"""Test that empty text content is converted to "(empty)" string."""
|
||||
messages = [
|
||||
{"role": "user", "content": ""}, # Empty string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""}, # Empty text in list content
|
||||
{"type": "text", "text": "Valid text"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Check that empty string content was converted
|
||||
user_msg = params["messages"][0]
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "(empty)")
|
||||
|
||||
# Check that empty text in list content was converted
|
||||
assistant_msg = params["messages"][1]
|
||||
self.assertIsInstance(assistant_msg["content"], list)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "(empty)")
|
||||
self.assertEqual(assistant_msg["content"][1]["text"], "Valid text")
|
||||
|
||||
def test_complex_message_content_preserved(self):
|
||||
"""Test that complex message structures (text + image) are properly converted to AWS Bedrock format."""
|
||||
# Create a complex message with both text and image content
|
||||
# Use a valid base64 string for the image
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Please describe it in detail."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = [
|
||||
complex_message,
|
||||
{"role": "assistant", "content": "I can see the image clearly."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify complex message structure is preserved and converted
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
|
||||
# Note: AWS Bedrock adapter reorders single images to come before text, like Anthropic
|
||||
# Check image part (should be moved to first position and converted from image_url to image)
|
||||
self.assertIn("image", user_msg["content"][0])
|
||||
self.assertEqual(user_msg["content"][0]["image"]["format"], "jpeg")
|
||||
self.assertIn("source", user_msg["content"][0]["image"])
|
||||
self.assertIn("bytes", user_msg["content"][0]["image"]["source"])
|
||||
|
||||
# Check first text part (moved to second position)
|
||||
self.assertEqual(user_msg["content"][1]["text"], "What do you see in this image?")
|
||||
|
||||
# Check second text part (moved to third position)
|
||||
self.assertEqual(user_msg["content"][2]["text"], "Please describe it in detail.")
|
||||
|
||||
def test_multiple_system_instructions_handling(self):
|
||||
"""Test that first system instruction is extracted, later ones converted to user messages."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "system", "content": "Remember to be concise."}, # Later system message
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System instruction should be extracted from first message (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# Should have 3 messages remaining (system message was removed, later system converted to user)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
self.assertEqual(params["messages"][0]["role"], "user")
|
||||
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
|
||||
self.assertEqual(params["messages"][1]["role"], "assistant")
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Hi there!")
|
||||
|
||||
# Later system message should be converted to user role
|
||||
self.assertEqual(params["messages"][2]["role"], "user")
|
||||
self.assertEqual(params["messages"][2]["content"][0]["text"], "Remember to be concise.")
|
||||
|
||||
def test_single_system_message_handling(self):
|
||||
"""Test that a single system message is extracted as system parameter and no messages remain."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# System should be extracted (in AWS Bedrock format)
|
||||
self.assertIsInstance(params["system"], list)
|
||||
self.assertEqual(len(params["system"]), 1)
|
||||
self.assertEqual(params["system"][0]["text"], "You are a helpful assistant.")
|
||||
|
||||
# No messages should remain after system extraction
|
||||
self.assertEqual(len(params["messages"]), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -7,10 +7,10 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartInterruptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
@@ -113,7 +113,7 @@ class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
LLMFullResponseStartFrame(),
|
||||
LLMTextFrame("Hello "),
|
||||
SleepFrame(),
|
||||
StartInterruptionFrame(),
|
||||
InterruptionFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
LLMTextFrame("Hello "),
|
||||
LLMTextFrame("there!"),
|
||||
@@ -122,7 +122,7 @@ class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames = [
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
LLMTextFrame,
|
||||
|
||||
286
tests/test_openai_agent_service.py
Normal file
286
tests/test_openai_agent_service.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for OpenAI Agent service."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add src to path for testing
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""Mock Agent for testing."""
|
||||
|
||||
def __init__(self, name="Test Agent", instructions="Test instructions"):
|
||||
self.name = name
|
||||
self.instructions = instructions
|
||||
self.tools = []
|
||||
self.handoffs = []
|
||||
|
||||
|
||||
class MockRunResult:
|
||||
"""Mock RunResult for testing."""
|
||||
|
||||
def __init__(self, final_output="Test response"):
|
||||
self.final_output = final_output
|
||||
|
||||
|
||||
class MockStreamEvent:
|
||||
"""Mock StreamEvent for testing."""
|
||||
|
||||
def __init__(self, event_type, data=None, item=None):
|
||||
self.type = event_type
|
||||
self.data = data
|
||||
self.item = item
|
||||
|
||||
|
||||
class MockMessageItem:
|
||||
"""Mock message item for testing."""
|
||||
|
||||
def __init__(self, content="Test content"):
|
||||
self.type = "message_output_item"
|
||||
self.content = content
|
||||
|
||||
|
||||
class MockRunner:
|
||||
"""Mock Runner for testing."""
|
||||
|
||||
@staticmethod
|
||||
async def run(agent, input_text, context=None):
|
||||
return MockRunResult("Mocked response")
|
||||
|
||||
@staticmethod
|
||||
def run_streamed(agent, input_text, context=None):
|
||||
class MockStreamResult:
|
||||
async def stream_events(self):
|
||||
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="Test "))
|
||||
yield MockStreamEvent("raw_response_event", data=MagicMock(delta="response"))
|
||||
yield MockStreamEvent(
|
||||
"run_item_stream_event", item=MockMessageItem("Test response")
|
||||
)
|
||||
|
||||
return MockStreamResult()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_agents():
|
||||
"""Mock the OpenAI Agents SDK imports."""
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"agents": MagicMock(),
|
||||
"agents.stream_events": MagicMock(),
|
||||
"agents.result": MagicMock(),
|
||||
},
|
||||
):
|
||||
# Mock the classes and functions we need
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.return_value = MockAgent()
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.run = AsyncMock(return_value=MockRunResult())
|
||||
mock_runner.run_streamed = MagicMock(return_value=MockRunner.run_streamed(None, None))
|
||||
|
||||
with (
|
||||
patch("pipecat.services.openai_agent.agent_service.Agent", mock_agent),
|
||||
patch("pipecat.services.openai_agent.agent_service.Runner", mock_runner),
|
||||
):
|
||||
yield {
|
||||
"Agent": mock_agent,
|
||||
"Runner": mock_runner,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_init(mock_openai_agents):
|
||||
"""Test OpenAI Agent service initialization."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
|
||||
)
|
||||
|
||||
assert service.agent.name == "Test Agent"
|
||||
assert service._streaming is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_process_text_frame_streaming(mock_openai_agents):
|
||||
"""Test processing text frame with streaming enabled."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=True
|
||||
)
|
||||
|
||||
# Mock the push_frame method to capture output
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
text_frame = TextFrame("Hello, agent!")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait a bit for async processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check that appropriate frames were generated
|
||||
assert len(output_frames) > 0
|
||||
assert any(isinstance(frame, LLMFullResponseStartFrame) for frame in output_frames)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_process_text_frame_non_streaming(mock_openai_agents):
|
||||
"""Test processing text frame with streaming disabled."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key", streaming=False
|
||||
)
|
||||
|
||||
# Mock the push_frame method to capture output
|
||||
output_frames = []
|
||||
|
||||
async def mock_push_frame(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
output_frames.append(frame)
|
||||
|
||||
service.push_frame = mock_push_frame
|
||||
|
||||
# Process a text frame
|
||||
text_frame = TextFrame("Hello, agent!")
|
||||
await service.process_frame(text_frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Wait a bit for async processing
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check that appropriate frames were generated
|
||||
assert len(output_frames) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_update_config(mock_openai_agents):
|
||||
"""Test updating agent configuration."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Update configuration
|
||||
service.update_agent_config(
|
||||
instructions="Updated instructions", model_config={"model": "gpt-4o", "temperature": 0.7}
|
||||
)
|
||||
|
||||
assert service.agent.instructions == "Updated instructions"
|
||||
assert service.agent.model_config["model"] == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_session_context(mock_openai_agents):
|
||||
"""Test session context management."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent",
|
||||
instructions="Test instructions",
|
||||
api_key="test-key",
|
||||
session_config={"user_id": "test-user"},
|
||||
)
|
||||
|
||||
# Get initial context
|
||||
context = service.get_session_context()
|
||||
assert context["user_id"] == "test-user"
|
||||
|
||||
# Update context
|
||||
service.update_session_context({"session_id": "test-session"})
|
||||
|
||||
updated_context = service.get_session_context()
|
||||
assert updated_context["user_id"] == "test-user"
|
||||
assert updated_context["session_id"] == "test-session"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_add_tools(mock_openai_agents):
|
||||
"""Test adding tools to the agent."""
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Define a test tool
|
||||
def test_tool():
|
||||
return "test result"
|
||||
|
||||
# Add the tool
|
||||
await service.add_tool(test_tool)
|
||||
|
||||
# Check if tool was added (this depends on the mock implementation)
|
||||
assert hasattr(service.agent, "tools")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_agent_service_lifecycle(mock_openai_agents):
|
||||
"""Test service lifecycle methods."""
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.services.openai_agent.agent_service import OpenAIAgentService
|
||||
|
||||
service = OpenAIAgentService(
|
||||
name="Test Agent", instructions="Test instructions", api_key="test-key"
|
||||
)
|
||||
|
||||
# Test start
|
||||
start_frame = StartFrame()
|
||||
await service.start(start_frame)
|
||||
|
||||
# Test cancel
|
||||
cancel_frame = CancelFrame()
|
||||
await service.cancel(cancel_frame)
|
||||
|
||||
# Test stop
|
||||
end_frame = EndFrame()
|
||||
await service.stop(end_frame)
|
||||
|
||||
|
||||
def test_openai_agent_service_import_error():
|
||||
"""Test that import error is handled gracefully."""
|
||||
# Mock the import to fail
|
||||
with patch.dict("sys.modules", {"agents": None}):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# This should trigger the import error
|
||||
import importlib
|
||||
|
||||
import pipecat.services.openai_agent.agent_service
|
||||
|
||||
importlib.reload(pipecat.services.openai_agent.agent_service)
|
||||
|
||||
assert "Missing module" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -65,7 +65,7 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
ignore_start=False,
|
||||
start_metadata={"foo": "bar"},
|
||||
pipeline_params=PipelineParams(start_metadata={"foo": "bar"}),
|
||||
)
|
||||
assert "foo" in received_down[-1].metadata
|
||||
|
||||
@@ -196,10 +196,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal start_received
|
||||
start_received = True
|
||||
|
||||
@task.event_handler("on_pipeline_ended")
|
||||
async def on_pipeline_ended(task, frame: EndFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal end_received
|
||||
end_received = True
|
||||
end_received = isinstance(frame, EndFrame)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -214,10 +214,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@task.event_handler("on_pipeline_stopped")
|
||||
async def on_pipeline_ended(task, frame: StopFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame: Frame):
|
||||
nonlocal stop_received
|
||||
stop_received = True
|
||||
stop_received = isinstance(frame, StopFrame)
|
||||
|
||||
await task.queue_frame(StopFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
@@ -441,10 +441,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def on_pipeline_started(task: PipelineTask, frame: StartFrame):
|
||||
await task.cancel()
|
||||
|
||||
@task.event_handler("on_pipeline_cancelled")
|
||||
async def on_pipeline_cancelled(task: PipelineTask, frame: CancelFrame):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
cancelled = isinstance(frame, CancelFrame)
|
||||
|
||||
try:
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
261
tests/test_run_inference.py
Normal file
261
tests/test_run_inference.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from anthropic import NOT_GIVEN
|
||||
from openai import NotGiven
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response."""
|
||||
# Create service with mocked client
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
]
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hello! How can I help you today?"
|
||||
service._client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.chat.completions.create.assert_called_once_with(
|
||||
model="gpt-4",
|
||||
messages=test_messages,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_client_exception():
|
||||
"""Test that exceptions from the client are propagated."""
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
|
||||
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.chat.completions.create.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Anthropic."""
|
||||
# Create service with mocked client
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Hello! How can I help you today?"
|
||||
service._client.messages.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(
|
||||
mock_context, enable_prompt_caching=False
|
||||
)
|
||||
service._client.messages.create.assert_called_once_with(
|
||||
model="claude-3-sonnet-20240229",
|
||||
messages=test_messages,
|
||||
system=test_system,
|
||||
max_tokens=8192,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_client_exception():
|
||||
"""Test that exceptions from the Anthropic client are propagated."""
|
||||
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
|
||||
messages=[], system="Test system", tools=[]
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.messages.create.side_effect = Exception("Anthropic API Error")
|
||||
|
||||
with pytest.raises(Exception, match="Anthropic API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for Google."""
|
||||
# Create service with mocked client
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
test_system = "You are a helpful assistant"
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=test_messages, system_instruction=test_system, tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [MagicMock()]
|
||||
mock_response.candidates[0].content = MagicMock()
|
||||
mock_response.candidates[0].content.parts = [MagicMock()]
|
||||
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
service._client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_client_exception():
|
||||
"""Test that exceptions from the Google client are propagated."""
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
|
||||
service._client = AsyncMock()
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
|
||||
messages=[], system_instruction="Test system", tools=NotGiven()
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(
|
||||
side_effect=Exception("Google API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Google API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
|
||||
# Create service and patch the session client method
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
# Setup mocks
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
|
||||
test_system = [{"text": "You are a helpful assistant"}]
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=test_messages, system=test_system, tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock the client and response
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
|
||||
}
|
||||
mock_client.converse.return_value = mock_response
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
async def mock_client_cm(*args, **kwargs):
|
||||
return mock_client
|
||||
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
# Execute
|
||||
result = await service.run_inference(mock_context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service.get_llm_adapter.assert_called_once()
|
||||
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
|
||||
mock_client.converse.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_client_exception():
|
||||
"""Test that exceptions from the AWS Bedrock client are propagated."""
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
||||
mock_context = MagicMock(spec=LLMContext)
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
|
||||
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
|
||||
)
|
||||
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
|
||||
|
||||
# Mock AWS client to raise exception
|
||||
mock_client = AsyncMock()
|
||||
mock_client.converse.side_effect = Exception("Bedrock API Error")
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
with pytest.raises(Exception, match="Bedrock API Error"):
|
||||
await service.run_inference(mock_context)
|
||||
@@ -14,7 +14,7 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
@@ -238,7 +238,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world!"),
|
||||
SleepFrame(),
|
||||
StartInterruptionFrame(), # User interrupts here
|
||||
InterruptionFrame(), # User interrupts here
|
||||
SleepFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="New"),
|
||||
@@ -252,7 +252,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame, # "Hello"
|
||||
TTSTextFrame, # "world!"
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionUpdateFrame, # First message (emitted due to interruption)
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame, # "New"
|
||||
|
||||
Reference in New Issue
Block a user