Compare commits
121 Commits
aleix/obse
...
aleix/queu
| Author | SHA1 | Date | |
|---|---|---|---|
| | 41695806e8 | | |
| | 7280e390d9 | | |
| | 4efc3f0a39 | | |
| | cb7e7a8aa3 | | |
| | 9136402846 | | |
| | 260fc76137 | | |
| | 7cfb9a4d15 | | |
| | 2089e0c974 | | |
| | 9e0b4fe5d1 | | |
| | 75ce632f84 | | |
| | efeb96c4e8 | | |
| | fb5438e9c2 | | |
| | 7da9f66e1c | | |
| | 9e16e3d614 | | |
| | 84d040c6d0 | | |
| | f3e0beb8f1 | | |
| | e00a1196ef | | |
| | 3867c0f8e7 | | |
| | cdf0953722 | | |
| | ed00f7d071 | | |
| | a3038afa02 | | |
| | f9ca0b8cc6 | | |
| | 2920aa5af4 | | |
| | 93c9cc4a0e | | |
| | b53f9235e4 | | |
| | 1491462d15 | | |
| | c78f779800 | | |
| | b013e375fb | | |
| | 52036138c1 | | |
| | 4ba9a42861 | | |
| | 27bff7a759 | | |
| | 896f8d85f7 | | |
| | ed06cdd2c7 | | |
| | 8473647269 | | |
| | 5579145a06 | | |
| | 35848d10b3 | | |
| | c7e223e85a | | |
| | 885b2d1d2f | | |
| | 73020be511 | | |
| | d388c057c0 | | |
| | c4d0f91a7f | | |
| | 467233be04 | | |
| | 2b02d08f4c | | |
| | 9fe265ea64 | | |
| | cc1f4ba81c | | |
| | 3784bdbd27 | | |
| | 4ffdc3b77c | | |
| | 38c9fa681a | | |
| | c477039954 | | |
| | d6ef3d64ac | | |
| | 6938152db6 | | |
| | 2154db07f0 | | |
| | 5e0803479e | | |
| | 3960c604a4 | | |
| | 394648f1c9 | | |
| | da5c4953d5 | | |
| | 2b7e1cb5b1 | | |
| | f182eafb40 | | |
| | 9f7f42e885 | | |
| | 9b8bce1914 | | |
| | 96d05e12fc | | |
| | 68c1069548 | | |
| | 5b64613f65 | | |
| | 1f9baefba8 | | |
| | 0c255d2618 | | |
| | a38206de9c | | |
| | 260f7c9b85 | | |
| | de294caed9 | | |
| | e40aa4f99a | | |
| | b1d413b9be | | |
| | 8cbad070ad | | |
| | 13569a5a5a | | |
| | d789334a60 | | |
| | 7668b27fc0 | | |
| | 6d30f441e8 | | |
| | a9e395b366 | | |
| | 5e5626f04f | | |
| | d80aa5b44e | | |
| | 80ef6dc4de | | |
| | 458549f7df | | |
| | a8405649d0 | | |
| | ce1a72850b | | |
| | 58de381746 | | |
| | bed2e894a2 | | |
| | b4de98cfb7 | | |
| | a4b9db9e07 | | |
| | 664111a3c9 | | |
| | aa964847f3 | | |
| | fa5cac7e0a | | |
| | b2b01861b2 | | |
| | f014f718eb | | |
| | 05ae8d3ffa | | |
| | 88c9e08bd8 | | |
| | 844f61dfea | | |
| | acb7d597cb | | |
| | 2b18f60261 | | |
| | 5b66133a6c | | |
| | 0c5bc6a57a | | |
| | 7981e00955 | | |
| | 5e39c0cfeb | | |
| | a444701929 | | |
| | f6c1eb5d9d | | |
| | a1d46cb26b | | |
| | 99ab148d88 | | |
| | d69fa5dba5 | | |
| | 0d30b000af | | |
| | e7c0e742d2 | | |
| | 2aff2dcca3 | | |
| | 288f8865c8 | | |
| | 8691870bcb | | |
| | e06146c237 | | |
| | c68e990cda | | |
| | 4583905313 | | |
| | 9cc498b1fa | | |
| | b3c5dc4045 | | |
| | 3824da7261 | | |
| | 855d567b1e | | |
| | b323a7bd88 | | |
| | fa011d0018 | | |
| | c510870736 | | |
| | e8783f6a33 | | |
57
CHANGELOG.md
@@ -5,16 +5,71 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
## [0.0.67] - 2025-05-07

### Added

- Added `DebugLogObserver` for detailed frame logging with configurable
  filtering by frame type and endpoint. This observer automatically extracts
  and formats all frame data fields for debug logging.

- `UserImageRequestFrame.video_source` field has been added to request an image
  from the desired video source.

- Added support for the AWS Nova Sonic speech-to-speech model with the new
  `AWSNovaSonicLLMService`.
  See https://docs.aws.amazon.com/nova/latest/userguide/speech.html.
  Note that it requires Python >= 3.12 and `pip install pipecat-ai[aws-nova-sonic]`.

- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`.

- Added `on_active_speaker_changed` event handler to the `DailyTransport` class.

- Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in
  `ElevenLabsTTSService`.

- Added support to `RimeHttpTTSService` for the `arcana` model.

### Changed

- Updated `ElevenLabsTTSService` to use the beta websocket API
  (multi-stream-input). This new API supports context_ids and cancelling those
  contexts, which greatly improves interruption handling.

- Observers `on_push_frame()` now take a single argument `FramePushed` instead
  of multiple arguments.

- Updated the default voice for `DeepgramTTSService` to `aura-2-helena-en`.

### Deprecated

- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead.

- Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now
  deprecated, use `on_push_frame(data: FramePushed)` instead.
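
The observer migration described in the entries above can be sketched as follows. This is a minimal, self-contained illustration and not Pipecat's actual code: the `FramePushed` dataclass is a stand-in whose field names are assumed from the deprecated argument list, and `CountingObserver` is a hypothetical observer.

```python
import asyncio
from dataclasses import dataclass
from typing import Any


@dataclass
class FramePushed:
    """Stand-in for the event object; the field names are assumptions."""

    source: Any
    destination: Any
    frame: Any
    direction: str
    timestamp: int


class CountingObserver:
    """Hypothetical observer that counts pushed frames by type name."""

    def __init__(self) -> None:
        self.counts: dict[str, int] = {}

    # Deprecated style: on_push_frame(self, src, dst, frame, direction, timestamp)
    # New style: one argument that bundles the same information.
    async def on_push_frame(self, data: FramePushed) -> None:
        name = type(data.frame).__name__
        self.counts[name] = self.counts.get(name, 0) + 1


async def demo() -> dict[str, int]:
    observer = CountingObserver()
    for frame in ("hello", "world", 42):
        await observer.on_push_frame(
            FramePushed(
                source="llm",
                destination="tts",
                frame=frame,
                direction="downstream",
                timestamp=0,
            )
        )
    return observer.counts


counts = asyncio.run(demo())
```

Bundling the arguments into one object means new fields can be added later without changing every observer's signature.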

### Fixed

- Fixed a `DailyTransport` issue that was causing problems when multiple audio or
  video sources were being captured.

- Fixed an `UltravoxSTTService` issue that would cause the service to generate
  all tokens as one word.

- Fixed a `PipelineTask` issue that would cause tasks to not be cancelled if
  the task was cancelled from outside of Pipecat.

- Fixed a `TaskManager` issue that was causing dangling tasks to be reported.

- Fixed an issue that could cause data to be sent to the transports when they
  were still not ready.

- Removed custom audio tracks from `DailyTransport` before leaving.
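
The `PipelineTask` cancellation entry above concerns a task that is cancelled by code outside the pipeline. The general pattern can be sketched with plain `asyncio`. This is an illustration under assumptions, not Pipecat's implementation: `run_pipeline` and `worker` are hypothetical names.

```python
import asyncio


async def worker(log: list[str], name: str) -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for long-running frame processing
    except asyncio.CancelledError:
        log.append(f"{name} cancelled")
        raise


async def run_pipeline(log: list[str]) -> None:
    """Hypothetical parent task that owns two child tasks."""
    children = [asyncio.create_task(worker(log, f"worker-{i}")) for i in range(2)]
    try:
        # asyncio.wait() does not cancel the children if the waiter is cancelled.
        await asyncio.wait(children)
    except asyncio.CancelledError:
        # Cancelled from outside: cancel the children explicitly and wait for
        # them, so that no task is left dangling.
        for child in children:
            child.cancel()
        await asyncio.gather(*children, return_exceptions=True)
        log.append("pipeline cancelled")
        raise


async def main() -> list[str]:
    log: list[str] = []
    task = asyncio.create_task(run_pipeline(log))
    await asyncio.sleep(0.01)  # give the children time to start
    task.cancel()  # cancellation requested by code outside the pipeline
    try:
        await task
    except asyncio.CancelledError:
        pass
    return log


log = asyncio.run(main())
```

Re-raising `CancelledError` after the cleanup lets the caller see that the task really was cancelled.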

### Removed

- Removed `CanonicalMetricsService` as it's no longer maintained.

## [0.0.66] - 2025-05-02

### Added
24
README.md
@@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
## 🧩 Available services
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
|
||||
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
| Category | Services |
|
||||
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
|
||||
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ pipecat-ai[anthropic]
pipecat-ai[assemblyai]
pipecat-ai[aws]
pipecat-ai[azure]
pipecat-ai[canonical]
pipecat-ai[cartesia]
pipecat-ai[cerebras]
pipecat-ai[deepseek]
161
examples/canonical-metrics/.gitignore
vendored
@@ -1,161 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
recordings/
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
runpod.toml
@@ -1,10 +0,0 @@
FROM python:3.10-bullseye
RUN mkdir /app
COPY *.py /app/
COPY requirements.txt /app/
WORKDIR /app
RUN pip3 install -r requirements.txt

EXPOSE 7860

CMD ["python3", "server.py"]
@@ -1,66 +0,0 @@
# Chatbot with canonical-metrics

This project implements a chatbot using a pipeline architecture that integrates audio processing, transcription, and a language model for conversational interactions. The chatbot runs in a Daily call and uses external services for text-to-speech and language model responses.

## Features

- **Audio Input and Output**: Captures microphone input and plays back audio responses.
- **Voice Activity Detection**: Uses Silero VAD to detect when the user is speaking.
- **Text-to-Speech**: Integrates ElevenLabs TTS service to convert text responses into audio.
- **Language Model Interaction**: Uses OpenAI's GPT-4 model to generate responses based on user input.
- **Transcription Services**: Captures and transcribes participant speech for analytics.
- **Metrics Collection**: Sends audio data for analysis via Canonical Metrics Service.

## Requirements

- Python 3.10+
- `python-dotenv`
- Additional libraries from the `pipecat` package.

## Setup

1. Clone the repository.
2. Install the required packages.
3. Set up environment variables for API keys:
   - `OPENAI_API_KEY`
   - `ELEVENLABS_API_KEY`
   - `CANONICAL_API_KEY`
   - `CANONICAL_API_URL`
4. Run the script.
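
A quick way to confirm step 3 before running the script is to check the environment up front. The sketch below is illustrative only: `missing_env_vars` is a hypothetical helper, and treating `CANONICAL_API_URL` as optional is an assumption.

```python
import os
from typing import Mapping, Sequence

# Names taken from the setup list above. Treating CANONICAL_API_URL as
# optional is an assumption of this sketch.
REQUIRED_VARS = ("OPENAI_API_KEY", "ELEVENLABS_API_KEY", "CANONICAL_API_KEY")


def missing_env_vars(
    env: Mapping[str, str] = os.environ,
    required: Sequence[str] = REQUIRED_VARS,
) -> list[str]:
    """Return the required variable names that are unset or empty in `env`."""
    return [name for name in required if not env.get(name)]


# Example with an explicit mapping instead of the real environment.
missing = missing_env_vars({"OPENAI_API_KEY": "sk-test", "ELEVENLABS_API_KEY": ""})
```

Calling `missing_env_vars()` with no arguments checks the real process environment, for example right after `load_dotenv()`.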

## Usage

The chatbot introduces itself and engages in conversations, providing brief and creative responses. It can support multiple languages with appropriate configuration.

## Events

- When the first participant joins, the bot starts recording and begins the conversation. When a participant leaves, the pipeline task is cancelled.

ℹ️ The first run might take extra time to start because the VAD (Voice Activity Detection) model needs to be downloaded.

## Get started

```bash
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt

cp env.example .env # and add your credentials

```

## Run the server

```bash
python server.py
```

Then, visit `http://localhost:7860/` in your browser to start a chatbot session.

## Build and test the Docker image

```bash
docker build -t chatbot .
docker run --env-file .env -p 7860:7860 chatbot
```
@@ -1,146 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys
import uuid

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import EndFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.services.canonical.metrics import CanonicalMetricsService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
    async with aiohttp.ClientSession() as session:
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Chatbot",
            DailyParams(
                audio_out_enabled=True,
                audio_in_enabled=True,
                video_out_enabled=False,
                vad_analyzer=SileroVADAnalyzer(),
                transcription_enabled=True,
                #
                # Spanish
                #
                # transcription_settings=DailyTranscriptionSettings(
                #     language="es",
                #     tier="nova",
                #     model="2-general"
                # )
            ),
        )

        tts = ElevenLabsTTSService(
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            #
            # English
            #
            voice_id="cgSgspJ2msm6clMCkdW9",
            #
            # Spanish
            #
            # model="eleven_multilingual_v2",
            # voice_id="gD1IexrzCvsXPHUuT0s3",
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

        messages = [
            {
                "role": "system",
                #
                # English
                #
                "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself. Keep all your responses to 12 words or fewer.",
                #
                # Spanish
                #
                # "content": "Eres Chatbot, un amigable y útil robot. Tu objetivo es demostrar tus capacidades de una manera breve. Tus respuestas se convertiran a audio así que nunca no debes incluir caracteres especiales. Contesta a lo que el usuario pregunte de una manera creativa, útil y breve. Empieza por presentarte a ti mismo.",
            },
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        """
        CanonicalMetrics uses AudioBufferProcessor under the hood to buffer the audio. On
        call completion, CanonicalMetrics will send the audio buffer to Canonical for
        analysis. Visit https://voice.canonical.chat to learn more.
        """
        audio_buffer_processor = AudioBufferProcessor(num_channels=2)
        canonical = CanonicalMetricsService(
            audio_buffer_processor=audio_buffer_processor,
            aiohttp_session=session,
            api_key=os.getenv("CANONICAL_API_KEY"),
            call_id=str(uuid.uuid4()),
            assistant="pipecat-chatbot",
            assistant_speaks_first=True,
            context=context,
        )
        pipeline = Pipeline(
            [
                transport.input(),  # microphone
                context_aggregator.user(),
                llm,
                tts,
                transport.output(),
                canonical,  # uploads audio buffer to Canonical AI for metrics
                audio_buffer_processor,  # captures audio into a buffer
                context_aggregator.assistant(),
            ]
        )

        task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await audio_buffer_processor.start_recording()
            await transport.capture_participant_transcription(participant["id"])
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        @transport.event_handler("on_participant_left")
        async def on_participant_left(transport, participant, reason):
            print(f"Participant left: {participant}")
            await task.cancel()

        @transport.event_handler("on_call_state_updated")
        async def on_call_state_updated(transport, state):
            if state == "left":
                # Here we don't want to cancel, we just want to finish sending
                # whatever is queued, so we use an EndFrame().
                await task.queue_frame(EndFrame())

        runner = PipelineRunner()

        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())
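
The comment in `on_call_state_updated` above distinguishes finishing queued work from cancelling. A toy sketch of that difference with plain `asyncio` follows. It is not Pipecat's API: `END`, `consumer`, `graceful` and `cancelled` are hypothetical names, with `END` playing the role of an end-of-stream frame.

```python
import asyncio

END = object()  # hypothetical sentinel playing the role of an end frame


async def consumer(queue: asyncio.Queue, seen: list) -> None:
    """Process items until the END sentinel arrives."""
    while True:
        item = await queue.get()
        if item is END:
            return
        seen.append(item)
        await asyncio.sleep(0)  # yield control, as real processing would


async def graceful() -> list:
    queue, seen = asyncio.Queue(), []
    for item in ("a", "b", "c", END):
        queue.put_nowait(item)
    await consumer(queue, seen)  # everything queued before END is processed
    return seen


async def cancelled() -> list:
    queue, seen = asyncio.Queue(), []
    for item in ("a", "b", "c"):
        queue.put_nowait(item)
    task = asyncio.create_task(consumer(queue, seen))
    await asyncio.sleep(0)  # let the consumer start working
    task.cancel()  # stop immediately; remaining items are dropped
    try:
        await task
    except asyncio.CancelledError:
        pass
    return seen


drained = asyncio.run(graceful())
partial = asyncio.run(cancelled())
```

In the sentinel case every queued item is handled before shutdown. In the cancelled case the consumer stops at its next suspension point, so later items are never processed.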
@@ -1,6 +0,0 @@
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
DAILY_API_KEY=7df...
OPENAI_API_KEY=sk-PL...
ELEVENLABS_API_KEY=aeb...
CANONICAL_API_KEY=can...
CANONICAL_API_URL=
@@ -1,5 +0,0 @@
python-dotenv
fastapi[all]
uvicorn
pipecat-ai[daily,openai,silero,elevenlabs,canonical]
@@ -1,55 +0,0 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import argparse
import os

import aiohttp

from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper


async def configure(aiohttp_session: aiohttp.ClientSession):
    parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
    parser.add_argument(
        "-u", "--url", type=str, required=False, help="URL of the Daily room to join"
    )
    parser.add_argument(
        "-k",
        "--apikey",
        type=str,
        required=False,
        help="Daily API Key (needed to create an owner token for the room)",
    )

    args, unknown = parser.parse_known_args()

    url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL")
    key = args.apikey or os.getenv("DAILY_API_KEY")

    if not url:
        raise Exception(
            "No Daily room specified. Use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL."
        )

    if not key:
        raise Exception(
            "No Daily API key specified. Use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers."
        )

    daily_rest_helper = DailyRESTHelper(
        daily_api_key=key,
        daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
        aiohttp_session=aiohttp_session,
    )

    # Create a meeting token for the given room with an expiration 1 hour in
    # the future.
    expiry_time: float = 60 * 60

    token = await daily_rest_helper.get_token(url, expiry_time)

    return (url, token)
@@ -1,139 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
bot_procs = {}
|
||||
|
||||
daily_helpers = {}
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
def cleanup():
|
||||
# Clean up function, just to be extra safe
|
||||
for entry in bot_procs.values():
|
||||
proc = entry[0]
|
||||
proc.terminate()
|
||||
proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
aiohttp_session = aiohttp.ClientSession()
|
||||
daily_helpers["rest"] = DailyRESTHelper(
|
||||
daily_api_key=os.getenv("DAILY_API_KEY", ""),
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
aiohttp_session=aiohttp_session,
|
||||
)
|
||||
yield
|
||||
await aiohttp_session.close()
|
||||
cleanup()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def start_agent(request: Request):
|
||||
print(f"!!! Creating room")
|
||||
room = await daily_helpers["rest"].create_room(DailyRoomParams())
|
||||
print(f"!!! Room URL: {room.url}")
|
||||
# Ensure the room property is present
|
||||
if not room.url:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Room creation did not return a URL. Cannot start agent without a target room!",
|
||||
)
|
||||
|
||||
# Check if there is already an existing process running in this room
|
||||
num_bots_in_room = sum(
|
||||
1 for proc in bot_procs.values() if proc[1] == room.url and proc[0].poll() is None
|
||||
)
|
||||
if num_bots_in_room >= MAX_BOTS_PER_ROOM:
|
||||
raise HTTPException(status_code=500, detail=f"Max bot limit reached for room: {room.url}")
|
||||
|
||||
# Get the token for the room
|
||||
token = await daily_helpers["rest"].get_token(room.url)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get token for room: {room.url}")
|
||||
|
||||
# Spawn a new agent, and join the user session
|
||||
# Note: this is mostly for demonstration purposes (refer to 'deployment' in README)
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
[f"python3 -m bot -u {room.url} -t {token}"],
|
||||
shell=True,
|
||||
bufsize=1,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
)
|
||||
bot_procs[proc.pid] = (proc, room.url)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start subprocess: {e}")
|
||||
|
||||
return RedirectResponse(room.url)
|
||||
|
||||
|
||||
@app.get("/status/{pid}")
|
||||
def get_status(pid: int):
|
||||
# Look up the subprocess
|
||||
proc = bot_procs.get(pid)
|
||||
|
||||
# If the subprocess doesn't exist, return an error
|
||||
if not proc:
|
||||
raise HTTPException(status_code=404, detail=f"Bot with process id: {pid} not found")
|
||||
|
||||
# Check the status of the subprocess
|
||||
if proc[0].poll() is None:
|
||||
status = "running"
|
||||
else:
|
||||
status = "finished"
|
||||
|
||||
return JSONResponse({"bot_id": pid, "status": status})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
default_host = os.getenv("HOST", "0.0.0.0")
|
||||
default_port = int(os.getenv("FAST_API_PORT", "7860"))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
|
||||
parser.add_argument("--host", type=str, default=default_host, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=default_port, help="Port number")
|
||||
parser.add_argument("--reload", action="store_true", help="Reload code on change")
|
||||
|
||||
config = parser.parse_args()
|
||||
|
||||
uvicorn.run(
|
||||
"server:app",
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
reload=config.reload,
|
||||
)
|
||||
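The handler in the file above starts the bot through `subprocess.Popen` with `shell=True` and a single interpolated command string, and it counts live processes per room before spawning. The following is a minimal sketch of the same two steps as small pure functions, assuming the same `bot` module and `-u`/`-t` flags. The helper names `build_bot_command` and `count_running_bots` are hypothetical and do not appear in the example server, and using `sys.executable` in place of `python3` is a substitution made here, not the original behavior.

```python
import subprocess
import sys
from typing import Dict, List, Tuple


def build_bot_command(room_url: str, token: str) -> List[str]:
    """Return an argv list for the bot module (hypothetical helper).

    Passing a list to Popen avoids shell parsing of the URL and token.
    """
    return [sys.executable, "-m", "bot", "-u", room_url, "-t", token]


def count_running_bots(
    bot_procs: Dict[int, Tuple[subprocess.Popen, str]], room_url: str
) -> int:
    """Count tracked processes for room_url that have not exited yet."""
    return sum(
        1
        for proc, url in bot_procs.values()
        if url == room_url and proc.poll() is None
    )
```

With these helpers, the spawn step could plausibly become `subprocess.Popen(build_bot_command(room.url, token), cwd=...)` with no `shell=True`, which keeps characters in the URL or token from being interpreted by a shell.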
@@ -4,6 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
@@ -21,44 +22,23 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
# Check if we're in local development mode
|
||||
LOCAL_RUN = os.getenv("LOCAL_RUN")
|
||||
if LOCAL_RUN:
|
||||
import asyncio
|
||||
import webbrowser
|
||||
|
||||
try:
|
||||
from local_runner import configure
|
||||
except ImportError:
|
||||
logger.error("Could not import local_runner module. Local development mode may not work.")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Check if we're in local development mode
|
||||
LOCAL_RUN = os.getenv("LOCAL_RUN")
|
||||
|
||||
async def main(room_url: str, token: str):
|
||||
|
||||
async def main(transport: DailyTransport):
|
||||
"""Main pipeline setup and execution function.
|
||||
|
||||
Args:
|
||||
room_url: The Daily room URL
|
||||
token: The Daily room token
|
||||
transport: The DailyTransport object for the bot
|
||||
"""
|
||||
logger.debug("Starting bot in room: {}", room_url)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"bot",
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
logger.debug("Starting bot")
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"), voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22"
|
||||
api_key=os.getenv("CARTESIA_API_KEY"), voice_id="71a7ad14-091c-4e8e-a314-022ece01c121"
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
@@ -126,10 +106,25 @@ async def bot(args: DailySessionArguments):
|
||||
body: The configuration object from the request body
|
||||
session_id: The session ID for logging
|
||||
"""
|
||||
from pipecat.audio.filters.krisp_filter import KrispFilter
|
||||
|
||||
logger.info(f"Bot process initialized {args.room_url} {args.token}")
|
||||
|
||||
transport = DailyTransport(
|
||||
args.room_url,
|
||||
args.token,
|
||||
"Pipecat Bot",
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_in_filter=None if LOCAL_RUN else KrispFilter(),
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
await main(args.room_url, args.token)
|
||||
await main(transport)
|
||||
logger.info("Bot process completed")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in bot process: {str(e)}")
|
||||
@@ -137,18 +132,27 @@ async def bot(args: DailySessionArguments):
|
||||
|
||||
|
||||
# Local development functions
|
||||
async def local_main():
|
||||
async def local_daily():
|
||||
"""Function for local development testing."""
|
||||
from local_runner import configure
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
logger.warning("_")
|
||||
logger.warning("_")
|
||||
logger.warning(f"Talk to your voice agent here: {room_url}")
|
||||
logger.warning("_")
|
||||
logger.warning("_")
|
||||
webbrowser.open(room_url)
|
||||
await main(room_url, token)
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Pipecat Bot",
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
await main(transport)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in local development mode: {e}")
|
||||
|
||||
@@ -156,6 +160,6 @@ async def local_main():
|
||||
# Local development entry point
|
||||
if LOCAL_RUN and __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(local_main())
|
||||
asyncio.run(local_daily())
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run in local mode: {e}")
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
CARTESIA_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
# Local dev only
|
||||
DAILY_API_KEY=
|
||||
@@ -7,6 +7,7 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
agent_name = "my-first-agent"
|
||||
image = "your-username/my-first-agent:0.1"
|
||||
image_credentials = "your-dockerhub-creds"
|
||||
secret_set = "my-first-agent-secrets"
|
||||
enable_krisp = true
|
||||
|
||||
[scaling]
|
||||
min_instances = 0
|
||||
|
||||
@@ -47,7 +47,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
live_options=LiveOptions(vad_events=True, utterance_end_ms="1000"),
|
||||
)
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -15,9 +14,9 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.tts import PollyTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.aws.stt import AWSTranscribeSTTService
|
||||
from pipecat.services.aws.tts import AWSPollyTTSService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -37,17 +36,19 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = AWSTranscribeSTTService()
|
||||
|
||||
tts = PollyTTSService(
|
||||
api_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
voice_id="Amy",
|
||||
params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"),
|
||||
tts = AWSPollyTTSService(
|
||||
region="us-west-2", # only specific regions support generative TTS
|
||||
voice_id="Joanna",
|
||||
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -85,7 +86,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
139
examples/foundational/14r-function-calling-aws.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.aws.stt import AWSTranscribeSTTService
|
||||
from pipecat.services.aws.tts import AWSPollyTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
stt = AWSTranscribeSTTService()
|
||||
|
||||
tts = AWSPollyTTSService(
|
||||
region="us-west-2", # only specific regions support generative TTS
|
||||
voice_id="Joanna",
|
||||
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
|
||||
)
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
|
||||
)
|
||||
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
267
examples/foundational/20e-persistent-context-aws-nova-sonic.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws_nova_sonic.aws import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(params: FunctionCallParams):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await params.result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
# async def get_saved_conversation_filenames(
|
||||
# function_name, tool_call_id, args, llm, context, result_callback
|
||||
# ):
|
||||
# pattern = re.compile(re.escape(BASE_FILENAME) + "\\d{8}_\\d{6}\\.json$")
|
||||
# matching_files = []
|
||||
|
||||
# for filename in os.listdir("."):
|
||||
# if pattern.match(filename):
|
||||
# matching_files.append(filename)
|
||||
|
||||
# await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# Remove the last few messages. In reverse order, they are:
# - the in-progress save tool call
# - the invocation of the save tool call
# - the user's request to save (which may span one or more messages)
# The simplest approach is to pop messages until the last one is an
# assistant response.
|
||||
while messages and not (
|
||||
messages[-1].get("role") == "assistant" and "content" in messages[-1]
|
||||
):
|
||||
messages.pop()
|
||||
if messages: # we never expect this to be empty
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(messages, indent=4)}"
|
||||
)
|
||||
json.dump(messages, file, indent=2)
|
||||
await params.result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(params: FunctionCallParams):
|
||||
async def _reset():
|
||||
filename = params.arguments["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
messages = json.load(file)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}",
|
||||
}
|
||||
)
|
||||
params.context.set_messages(messages)
|
||||
await params.llm.reset_conversation()
|
||||
await params.llm.trigger_assistant_response()
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
asyncio.create_task(_reset())
|
||||
|
||||
|
||||
get_current_weather_tool = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
save_conversation_tool = FunctionSchema(
|
||||
name="save_conversation",
|
||||
description="Save the current conversation. Use this function to persist the current conversation to external storage.",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
|
||||
get_saved_conversation_filenames_tool = FunctionSchema(
|
||||
name="get_saved_conversation_filenames",
|
||||
description="Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
|
||||
load_conversation_tool = FunctionSchema(
|
||||
name="load_conversation",
|
||||
description="Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
properties={
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
required=["filename"],
|
||||
)
|
||||
|
||||
tools = ToolsSchema(
|
||||
standard_tools=[
|
||||
get_current_weather_tool,
|
||||
save_conversation_tool,
|
||||
get_saved_conversation_filenames_tool,
|
||||
load_conversation_tool,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
),
|
||||
)
|
||||
|
||||
# Specify initial system instruction.
|
||||
# HACK: note that, for now, we need to inject a special bit of text into this instruction to
|
||||
# allow the first assistant response to be programmatically triggered (which happens in the
|
||||
# on_client_connected handler, below)
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken dialog exchanging "
|
||||
"the transcripts of a natural real-time conversation. Keep your responses short, generally "
|
||||
"two or three sentences for chatty scenarios. "
|
||||
f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
|
||||
)
|
||||
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region
|
||||
voice_id="tiffany", # matthew, tiffany, amy
|
||||
# you could choose to pass instruction here rather than via context
|
||||
# system_instruction=system_instruction,
|
||||
# you could choose to pass tools here rather than via context
|
||||
# tools=tools
|
||||
)
|
||||
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
# HACK: for now, we need this special way of triggering the first assistant response in AWS
|
||||
# Nova Sonic. Note that this trigger requires a special corresponding bit of text in the
|
||||
# system instruction. In the future, simply queueing the context frame should be sufficient.
|
||||
await llm.trigger_assistant_response()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
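The `save_conversation` function in the file above drops trailing messages until the last one is an assistant response with content, so the saved history excludes the save request itself. The following is a minimal sketch of that trimming step in isolation, assuming messages are plain dicts with a `role` key. The name `trim_to_last_assistant` is hypothetical, and unlike the original loop this version works on a copy instead of mutating the list it is given.

```python
from typing import Any, Dict, List


def trim_to_last_assistant(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a copy of messages cut after the last assistant message with content."""
    trimmed = list(messages)  # shallow copy, the caller's list is left untouched
    while trimmed and not (
        trimmed[-1].get("role") == "assistant" and "content" in trimmed[-1]
    ):
        trimmed.pop()
    return trimmed
```

If no assistant message with content exists, the result is an empty list, which matches the original code's `if messages:` guard before writing the file.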
@@ -14,19 +14,26 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
EndFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.observers.loggers.llm_log_observer import LLMLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -34,7 +41,7 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class DebugObserver(BaseObserver):
|
||||
class CustomObserver(BaseObserver):
|
||||
"""Observer to log interruptions and bot speaking events to the console.
|
||||
|
||||
Logs all frame instances of:
|
||||
@@ -46,21 +53,20 @@ class DebugObserver(BaseObserver):
|
||||
Log format: [EVENT TYPE]: [source processor] → [destination processor] at [timestamp]s
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
src = data.source
|
||||
dst = data.destination
|
||||
frame = data.frame
|
||||
direction = data.direction
|
||||
timestamp = data.timestamp
|
||||
|
||||
# Convert timestamp to seconds for readability
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
# Create direction arrow
|
||||
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport):
|
||||
logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
@@ -119,7 +125,17 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
observers=[DebugObserver(), LLMLogObserver()],
|
||||
observers=[
|
||||
CustomObserver(),
|
||||
LLMLogObserver(),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
|
||||
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
|
||||
EndFrame: None,
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
|
||||
173
examples/foundational/39-aws-nova-sonic.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Initialize the SmallWebRTCTransport with the connection
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_enabled=True,
|
||||
camera_in_enabled=False,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
),
|
||||
)
|
||||
|
||||
# Specify initial system instruction.
|
||||
# HACK: note that, for now, we need to inject a special bit of text into this instruction to
|
||||
# allow the first assistant response to be programmatically triggered (which happens in the
|
||||
# on_client_connected handler, below)
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken dialog exchanging "
|
||||
"the transcripts of a natural real-time conversation. Keep your responses short, generally "
|
||||
"two or three sentences for chatty scenarios. "
|
||||
f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
|
||||
)
|
||||
|
||||
# Create the AWS Nova Sonic LLM service
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region
|
||||
voice_id="tiffany", # matthew, tiffany, amy
|
||||
# you could choose to pass instruction here rather than via context
|
||||
# system_instruction=system_instruction
|
||||
# you could choose to pass tools here rather than via context
|
||||
# tools=tools
|
||||
)
|
||||
|
||||
# Register function for function calls
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
# AWSNovaSonicLLMService adapts OpenAI LLM context objects in the standard message format to
|
||||
# what's expected by Nova Sonic.
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a fun fact!",
|
||||
},
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Build the pipeline
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Handle client connection event
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
# HACK: for now, we need this special way of triggering the first assistant response in AWS
|
||||
# Nova Sonic. Note that this trigger requires a special corresponding bit of text in the
|
||||
# system instruction. In the future, simply queueing the context frame should be sufficient.
|
||||
await llm.trigger_assistant_response()
|
||||
|
||||
# Handle client disconnection events
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
# Run the pipeline
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
@@ -10,12 +10,16 @@ import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(override=True)
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
|
||||
@@ -10,12 +10,16 @@ import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(override=True)
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
|
||||
@@ -41,13 +41,13 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "assemblyai~=0.37.0" ]
|
||||
aws = [ "boto3~=1.37.16" ]
|
||||
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
canonical = [ "aiofiles~=24.1.0" ]
|
||||
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.18.1" ]
|
||||
daily = [ "daily-python~=0.18.2" ]
|
||||
deepgram = [ "deepgram-sdk~=3.8.0" ]
|
||||
elevenlabs = [ "websockets~=13.1" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
@@ -97,6 +97,7 @@ where = ["src"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"pipecat" = ["py.typed"]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
40
src/pipecat/adapters/services/aws_nova_sonic_adapter.py
Normal file
@@ -0,0 +1,40 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import json
from typing import Any, Dict, List

from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema


class AWSNovaSonicLLMAdapter(BaseLLMAdapter):
    @staticmethod
    def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
        return {
            "toolSpec": {
                "name": function.name,
                "description": function.description,
                "inputSchema": {
                    "json": json.dumps(
                        {
                            "type": "object",
                            "properties": function.properties,
                            "required": function.required,
                        }
                    )
                },
            }
        }

    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
        """Converts function schemas to AWS Nova Sonic function-calling format.

        :return: AWS Nova Sonic formatted function call definition.
        """

        functions_schema = tools_schema.standard_tools
        return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema]
||||
38
src/pipecat/adapters/services/bedrock_adapter.py
Normal file
@@ -0,0 +1,38 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

from typing import Any, Dict, List

from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema


class AWSBedrockLLMAdapter(BaseLLMAdapter):
    @staticmethod
    def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
        return {
            "toolSpec": {
                "name": function.name,
                "description": function.description,
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": function.properties,
                        "required": function.required,
                    },
                },
            }
        }

    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
        """Converts function schemas to Bedrock's function-calling format.

        :return: Bedrock formatted function call definition.
        """

        functions_schema = tools_schema.standard_tools
        return [self._to_bedrock_function_format(func) for func in functions_schema]
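The two adapters above emit the same `toolSpec` envelope and differ only in how `inputSchema.json` is carried: the Nova Sonic adapter serializes the schema with `json.dumps`, while the Bedrock adapter embeds it as a dict. The following is a minimal sketch of that difference. `FunctionSchemaStub`, `input_schema`, `to_bedrock_tool`, and `to_nova_sonic_tool` are hypothetical stand-ins written for illustration, so the snippet runs without pipecat installed; they are not the library's API.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FunctionSchemaStub:
    """Hypothetical stand-in for pipecat's FunctionSchema (illustration only)."""

    name: str
    description: str
    properties: Dict[str, Any] = field(default_factory=dict)
    required: List[str] = field(default_factory=list)


def input_schema(function: FunctionSchemaStub) -> Dict[str, Any]:
    # The JSON Schema object both adapters build from the function definition.
    return {
        "type": "object",
        "properties": function.properties,
        "required": function.required,
    }


def to_bedrock_tool(function: FunctionSchemaStub) -> Dict[str, Any]:
    # Bedrock-style: the schema is embedded as a plain dict.
    return {
        "toolSpec": {
            "name": function.name,
            "description": function.description,
            "inputSchema": {"json": input_schema(function)},
        }
    }


def to_nova_sonic_tool(function: FunctionSchemaStub) -> Dict[str, Any]:
    # Nova Sonic-style: the same schema, serialized to a JSON string.
    return {
        "toolSpec": {
            "name": function.name,
            "description": function.description,
            "inputSchema": {"json": json.dumps(input_schema(function))},
        }
    }


weather = FunctionSchemaStub(
    name="get_current_weather",
    description="Get the current weather",
    properties={"location": {"type": "string"}},
    required=["location"],
)
bedrock_tool = to_bedrock_tool(weather)
nova_tool = to_nova_sonic_tool(weather)
```

Decoding the Nova Sonic string with `json.loads` yields the same dict the Bedrock variant carries directly, which is a quick way to check that the two formats stay in sync.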
|
||||
@@ -77,8 +77,8 @@ class Frame:
|
||||
|
||||
@dataclass
|
||||
class SystemFrame(Frame):
|
||||
"""System frames are frames that are not internally queued by any of the
|
||||
frame processors and should be processed immediately.
|
||||
"""A frame that takes higher priority than other frames. System frames are
|
||||
handled in order and are not affected by user interruptions.
|
||||
|
||||
"""
|
||||
|
||||
@@ -87,8 +87,9 @@ class SystemFrame(Frame):
|
||||
|
||||
@dataclass
|
||||
class DataFrame(Frame):
|
||||
"""Data frames are frames that will be processed in order and usually
|
||||
contain data such as LLM context, text, audio or images.
|
||||
"""A frame that is processed in order and usually contains data such as LLM
|
||||
context, text, audio or images. Data frames are cancelled by user
|
||||
interruptions.
|
||||
|
||||
"""
|
||||
|
||||
@@ -97,9 +98,9 @@ class DataFrame(Frame):
|
||||
|
||||
@dataclass
|
||||
class ControlFrame(Frame):
|
||||
"""Control frames are frames that, similar to data frames, will be processed
|
||||
in order and usually contain control information such as frames to update
|
||||
settings or to end the pipeline.
|
||||
"""A frame that, like data frames, is processed in order and usually contains
|
||||
control information, such as a request to update settings or to end the pipeline after
|
||||
everything is flushed. Control frames are cancelled by user interruptions.
|
||||
|
||||
"""
|
||||
|
||||
@@ -690,7 +691,7 @@ class FunctionCallResultFrame(SystemFrame):
|
||||
|
||||
@dataclass
|
||||
class STTMuteFrame(SystemFrame):
|
||||
"""System frame to mute/unmute the STT service."""
|
||||
"""A frame to mute/unmute the STT service."""
|
||||
|
||||
mute: bool
|
||||
|
||||
@@ -715,9 +716,10 @@ class UserImageRequestFrame(SystemFrame):
|
||||
context: Optional[Any] = None
|
||||
function_name: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
video_source: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, function: {self.function_name}, request: {self.tool_call_id})"
|
||||
return f"{self.name}(user: {self.user_id}, video_source: {self.video_source}, function: {self.function_name}, request: {self.tool_call_id})"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -795,7 +797,7 @@ class EndFrame(ControlFrame):
|
||||
should be shut down. If the transport receives this frame, it will stop
|
||||
sending frames to its output channel(s) and close all its threads. Note
|
||||
that this is a control frame, which means it will be received in the order it
|
||||
was sent (unline system frames).
|
||||
was sent.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -5,9 +5,38 @@
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import TYPE_CHECKING
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FramePushed:
|
||||
"""Represents an event where a frame is pushed from one processor to another
|
||||
within the pipeline.
|
||||
|
||||
This data structure is typically used by observers to track the flow of
|
||||
frames through the pipeline for logging, debugging, or analytics purposes.
|
||||
|
||||
Attributes:
|
||||
source (FrameProcessor): The processor sending the frame.
|
||||
destination (FrameProcessor): The processor receiving the frame.
|
||||
frame (Frame): The frame being transferred.
|
||||
direction (FrameDirection): The direction of the transfer (e.g., downstream or upstream).
|
||||
timestamp (int): The time when the frame was pushed, based on the pipeline clock.
|
||||
|
||||
"""
|
||||
|
||||
source: "FrameProcessor"
|
||||
destination: "FrameProcessor"
|
||||
frame: Frame
|
||||
direction: "FrameDirection"
|
||||
timestamp: int
|
||||
|
||||
|
||||
class BaseObserver(ABC):
|
||||
@@ -19,26 +48,15 @@ class BaseObserver(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
"""Abstract method to handle the event when a frame is pushed from one
|
||||
processor to another.
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
"""Handle the event when a frame is pushed from one processor to another.
|
||||
|
||||
This method should be implemented by subclasses to define specific
|
||||
behavior (e.g., logging, monitoring, debugging) when a frame is
|
||||
transferred through the pipeline.
|
||||
|
||||
Args:
|
||||
src (FrameProcessor): The source frame processor that is sending the frame.
|
||||
dst (FrameProcessor): The destination frame processor that will receive the frame.
|
||||
frame (Frame): The frame being transferred between processors.
|
||||
direction (FrameDirection): The direction of the frame transfer.
|
||||
timestamp (int): The timestamp when the frame was pushed (based on the pipeline clock).
|
||||
|
||||
This method should be implemented by subclasses to define specific behavior
|
||||
when a frame is pushed.
|
||||
data (FramePushed): The event data containing details about the frame transfer.
|
||||
|
||||
"""
|
||||
pass
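The change above collapses the five positional arguments of `on_push_frame` into a single `FramePushed` event object. The following is a minimal sketch of an observer written against the new shape. `FramePushedStub` and `CollectingObserver` are hypothetical stand-ins (plain values replace real processors and frames) so the snippet runs without pipecat. The timestamp is assumed to be in nanoseconds, matching the `timestamp / 1_000_000_000` conversion used in the examples in this diff.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FramePushedStub:
    """Hypothetical stand-in mirroring the fields of FramePushed."""

    source: Any
    destination: Any
    frame: Any
    direction: Any
    timestamp: int  # assumed nanoseconds on the pipeline clock


class CollectingObserver:
    """Illustrative observer using the single-argument signature."""

    def __init__(self) -> None:
        self.lines: List[str] = []

    async def on_push_frame(self, data: FramePushedStub) -> None:
        # Read everything from the event object instead of positional arguments.
        time_sec = data.timestamp / 1_000_000_000
        self.lines.append(
            f"{data.frame}: {data.source} -> {data.destination} at {time_sec:.2f}s"
        )


def demo() -> List[str]:
    observer = CollectingObserver()
    event = FramePushedStub("stt", "llm", "TranscriptionFrame", "DOWNSTREAM", 1_500_000_000)
    asyncio.run(observer.on_push_frame(event))
    return observer.lines
```

One benefit of the event object is that new fields can be added later without changing every observer's signature.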
|
||||
|
||||
218
src/pipecat/observers/loggers/debug_log_observer.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class FrameEndpoint(Enum):
|
||||
"""Specifies which endpoint (source or destination) to filter on."""
|
||||
|
||||
SOURCE = auto()
|
||||
DESTINATION = auto()
|
||||
|
||||
|
||||
class DebugLogObserver(BaseObserver):
|
||||
"""Observer that logs frame activity with detailed content to the console.
|
||||
|
||||
Automatically extracts and formats data from any frame type, making it useful
|
||||
for debugging pipeline behavior without needing frame-specific observers.
|
||||
|
||||
Args:
|
||||
frame_types: Optional tuple of frame types to log, or a dict with frame type
|
||||
filters. If None, logs all frame types.
|
||||
exclude_fields: Optional set of field names to exclude from logging.
|
||||
|
||||
Examples:
|
||||
Log all frames from all services:
|
||||
```python
|
||||
observers = DebugLogObserver()
|
||||
```
|
||||
|
||||
Log specific frame types from any source/destination:
|
||||
```python
|
||||
from pipecat.frames.frames import TranscriptionFrame, InterimTranscriptionFrame
|
||||
observers = DebugLogObserver(frame_types=(TranscriptionFrame, InterimTranscriptionFrame))
|
||||
```
|
||||
|
||||
Log frames with specific source/destination filters:
|
||||
```python
|
||||
from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
|
||||
from pipecat.transports.base_output_transport import BaseOutputTransport
|
||||
from pipecat.services.stt_service import STTService
|
||||
|
||||
observers = DebugLogObserver(frame_types={
|
||||
# Only log StartInterruptionFrame when source is BaseOutputTransport
|
||||
StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
|
||||
# Only log UserStartedSpeakingFrame when destination is STTService
|
||||
UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION),
|
||||
|
||||
# Log LLMTextFrame regardless of source or destination type
|
||||
LLMTextFrame: None
|
||||
})
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_types: Optional[
|
||||
Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]]
|
||||
] = None,
|
||||
exclude_fields: Optional[Set[str]] = None,
|
||||
):
|
||||
"""Initialize the debug log observer.
|
||||
|
||||
Args:
|
||||
frame_types: Tuple of frame types to log, or a dict mapping frame types to
|
||||
filter configurations. Filter configs can be:
|
||||
- None to log all instances of the frame type
|
||||
- A tuple of (service_type, endpoint) to filter on a specific service
|
||||
and endpoint (SOURCE or DESTINATION)
|
||||
If None is provided instead of a tuple/dict, log all frames.
|
||||
exclude_fields: Set of field names to exclude from logging. If None, only binary
|
||||
data fields are excluded.
|
||||
"""
|
||||
# Process frame filters
|
||||
self.frame_filters = {}
|
||||
|
||||
if frame_types is not None:
|
||||
if isinstance(frame_types, tuple):
|
||||
# Tuple of frame types - log all instances
|
||||
self.frame_filters = {frame_type: None for frame_type in frame_types}
|
||||
else:
|
||||
# Dict of frame types with filters
|
||||
self.frame_filters = frame_types
|
||||
|
||||
# By default, exclude binary data fields that would clutter logs
|
||||
self.exclude_fields = (
|
||||
exclude_fields
|
||||
if exclude_fields is not None
|
||||
else {
|
||||
"audio", # Skip binary audio data
|
||||
"image", # Skip binary image data
|
||||
"images", # Skip lists of images
|
||||
}
|
||||
)
|
||||
|
||||
def _format_value(self, value):
|
||||
"""Format a value for logging.
|
||||
|
||||
Args:
|
||||
value: The value to format.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the value suitable for logging.
|
||||
"""
|
||||
if value is None:
|
||||
return "None"
|
||||
elif isinstance(value, str):
|
||||
return f"{value!r}"
|
||||
elif isinstance(value, (list, tuple)):
|
||||
if len(value) == 0:
|
||||
return "[]"
|
||||
if isinstance(value[0], dict) and len(value) > 3:
|
||||
# For message lists, just show count
|
||||
return f"{len(value)} items"
|
||||
return str(value)
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
return f"{len(value)} bytes"
|
||||
elif hasattr(value, "get_messages_for_logging") and callable(
|
||||
getattr(value, "get_messages_for_logging")
|
||||
):
|
||||
# Special case for OpenAI context
|
||||
return f"{value.__class__.__name__} with messages: {value.get_messages_for_logging()}"
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
def _should_log_frame(self, frame, src, dst):
|
||||
"""Determine if a frame should be logged based on filters.
|
||||
|
||||
Args:
|
||||
frame: The frame being processed
|
||||
src: The source component
|
||||
dst: The destination component
|
||||
|
||||
Returns:
|
||||
bool: True if the frame should be logged, False otherwise
|
||||
"""
|
||||
# If no filters, log all frames
|
||||
if not self.frame_filters:
|
||||
return True
|
||||
|
||||
# Check if this frame type is in our filters
|
||||
for frame_type, filter_config in self.frame_filters.items():
|
||||
if isinstance(frame, frame_type):
|
||||
# If filter is None, log all instances of this frame type
|
||||
if filter_config is None:
|
||||
return True
|
||||
|
||||
# Otherwise, check the specific filter
|
||||
service_type, endpoint = filter_config
|
||||
|
||||
if endpoint == FrameEndpoint.SOURCE:
|
||||
return isinstance(src, service_type)
|
||||
elif endpoint == FrameEndpoint.DESTINATION:
|
||||
return isinstance(dst, service_type)
|
||||
|
||||
return False
|
||||
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
"""Process a frame being pushed into the pipeline.
|
||||
|
||||
Logs frame details to the console with all relevant fields and values.
|
||||
|
||||
Args:
|
||||
data: Event data containing the frame, source, destination, direction, and timestamp.
|
||||
"""
|
||||
src = data.source
|
||||
dst = data.destination
|
||||
frame = data.frame
|
||||
direction = data.direction
|
||||
timestamp = data.timestamp
|
||||
|
||||
# Check if we should log this frame
|
||||
if not self._should_log_frame(frame, src, dst):
|
||||
return
|
||||
|
||||
# Format direction arrow
|
||||
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
|
||||
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
class_name = frame.__class__.__name__
|
||||
|
||||
# Build frame representation
|
||||
frame_details = []
|
||||
|
||||
# If dataclass, extract fields
|
||||
if is_dataclass(frame):
|
||||
for field in fields(frame):
|
||||
if field.name in self.exclude_fields:
|
||||
continue
|
||||
|
||||
value = getattr(frame, field.name)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
formatted_value = self._format_value(value)
|
||||
frame_details.append(f"{field.name}: {formatted_value}")
|
||||
|
||||
# Format the message
|
||||
if frame_details:
|
||||
details = ", ".join(frame_details)
|
||||
message = f"{class_name} {details} at {time_sec:.2f}s"
|
||||
else:
|
||||
message = f"{class_name} at {time_sec:.2f}s"
|
||||
|
||||
# Log the message
|
||||
logger.debug(f"{src} {arrow} {dst}: {message}")
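The filtering rules in `_should_log_frame` can be sketched in isolation. This is a simplified illustration under stated assumptions: `Endpoint`, the frame classes, and the transport classes below are hypothetical stand-ins, not pipecat types, and `should_log` is a free function rather than the observer method.

```python
from enum import Enum, auto
from typing import Dict, Optional, Tuple


class Endpoint(Enum):
    SOURCE = auto()
    DESTINATION = auto()


# Hypothetical stand-ins for frame and processor types.
class TextFrame:
    pass


class EndFrame:
    pass


class AudioFrame:
    pass


class InputTransport:
    pass


class OutputTransport:
    pass


Filters = Dict[type, Optional[Tuple[type, Endpoint]]]


def should_log(frame: object, src: object, dst: object, filters: Filters) -> bool:
    # No filters configured: log everything.
    if not filters:
        return True
    for frame_type, config in filters.items():
        if isinstance(frame, frame_type):
            # None means "log every instance of this frame type".
            if config is None:
                return True
            service_type, endpoint = config
            if endpoint is Endpoint.SOURCE:
                return isinstance(src, service_type)
            return isinstance(dst, service_type)
    # Frame type not listed in the filters.
    return False


filters: Filters = {
    TextFrame: (OutputTransport, Endpoint.DESTINATION),
    EndFrame: None,
}
```

Note that the first matching frame type decides the outcome. If a frame is an instance of several listed types (for example through subclassing), dict insertion order determines which filter applies.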
|
||||
@@ -7,7 +7,6 @@
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -15,9 +14,9 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
|
||||
@@ -38,14 +37,13 @@ class LLMLogObserver(BaseObserver):
|
||||
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
src = data.source
|
||||
dst = data.destination
|
||||
frame = data.frame
|
||||
direction = data.direction
|
||||
timestamp = data.timestamp
|
||||
|
||||
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
|
||||
return
|
||||
|
||||
|
||||
@@ -7,12 +7,10 @@
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.services.stt_service import STTService
|
||||
|
||||
|
||||
@@ -29,14 +27,11 @@ class TranscriptionLogObserver(BaseObserver):
|
||||
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
src = data.source
|
||||
frame = data.frame
|
||||
timestamp = data.timestamp
|
||||
|
||||
if not isinstance(src, STTService):
|
||||
return
|
||||
|
||||
|
||||
@@ -286,12 +286,7 @@ class PipelineTask(BaseTask):
|
||||
async def cancel(self):
|
||||
"""Stops the running pipeline immediately."""
|
||||
logger.debug(f"Canceling pipeline task {self}")
|
||||
# Make sure everything is cleaned up downstream. This is sent
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._source.push_frame(CancelFrame())
|
||||
# Only cancel the push task. Everything else will be cancelled in run().
|
||||
await self._task_manager.cancel_task(self._process_push_task)
|
||||
await self._cancel()
|
||||
|
||||
async def run(self):
|
||||
"""Starts and manages the pipeline execution until completion or cancellation."""
|
||||
@@ -309,11 +304,17 @@ class PipelineTask(BaseTask):
|
||||
# well, because you get a CancelledError in every place you are
|
||||
# awaiting a task.
|
||||
pass
|
||||
await self._cancel_tasks()
|
||||
await self._cleanup(cleanup_pipeline)
|
||||
if self._check_dangling_tasks:
|
||||
self._print_dangling_tasks()
|
||||
self._finished = True
|
||||
finally:
|
||||
# It's possible that we get an asyncio.CancelledError from the
|
||||
# outside, if so we need to make sure everything gets cancelled
|
||||
# properly.
|
||||
if cleanup_pipeline:
|
||||
await self._cancel()
|
||||
await self._cancel_tasks()
|
||||
await self._cleanup(cleanup_pipeline)
|
||||
if self._check_dangling_tasks:
|
||||
self._print_dangling_tasks()
|
||||
self._finished = True
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
"""Queue a single frame to be pushed down the pipeline.
|
||||
@@ -336,6 +337,14 @@ class PipelineTask(BaseTask):
|
||||
for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
|
||||
async def _cancel(self):
|
||||
# Make sure everything is cleaned up downstream. This is sent
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._source.push_frame(CancelFrame())
|
||||
# Only cancel the push task. Everything else will be cancelled in run().
|
||||
await self._task_manager.cancel_task(self._process_push_task)
|
||||
|
||||
async def _create_tasks(self):
|
||||
self._process_up_task = self._task_manager.create_task(
|
||||
self._process_up_queue(), f"{self}::_process_up_queue"
|
||||
|
||||
@@ -5,13 +5,12 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import List
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.utils.asyncio import BaseTaskManager
|
||||
|
||||
|
||||
@@ -27,20 +26,6 @@ class Proxy:
|
||||
observer: BaseObserver
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObserverData:
|
||||
"""This is the data we receive from the main observer and that we put into a
|
||||
proxy queue for later processing.
|
||||
|
||||
"""
|
||||
|
||||
src: FrameProcessor
|
||||
dst: FrameProcessor
|
||||
frame: Frame
|
||||
direction: FrameDirection
|
||||
timestamp: int
|
||||
|
||||
|
||||
class TaskObserver(BaseObserver):
|
||||
"""This is a pipeline frame observer that is meant to be used as a proxy to
|
||||
the user provided observers. That is, this is the observer that should be
|
||||
@@ -68,20 +53,9 @@ class TaskObserver(BaseObserver):
|
||||
for proxy in self._proxies:
|
||||
await self._task_manager.cancel_task(proxy.task)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
for proxy in self._proxies:
|
||||
await proxy.queue.put(
|
||||
ObserverData(
|
||||
src=src, dst=dst, frame=frame, direction=direction, timestamp=timestamp
|
||||
)
|
||||
)
|
||||
await proxy.queue.put(data)
|
||||
|
||||
def _create_proxies(self, observers) -> List[Proxy]:
|
||||
proxies = []
|
||||
@@ -96,8 +70,26 @@ class TaskObserver(BaseObserver):
|
||||
return proxies
|
||||
|
||||
async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver):
|
||||
warning_reported = False
|
||||
while True:
|
||||
data = await queue.get()
|
||||
await observer.on_push_frame(
|
||||
data.src, data.dst, data.frame, data.direction, data.timestamp
|
||||
)
|
||||
|
||||
signature = inspect.signature(observer.on_push_frame)
|
||||
if len(signature.parameters) > 1:
|
||||
if not warning_reported:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Observer `on_push_frame(source, destination, frame, direction, timestamp)` is deprecated, use `on_push_frame(data: FramePushed)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
warning_reported = True
|
||||
await observer.on_push_frame(
|
||||
data.source, data.destination, data.frame, data.direction, data.timestamp
|
||||
)
|
||||
else:
|
||||
await observer.on_push_frame(data)
|
||||
|
||||
queue.task_done()
|
||||
|
||||
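The backward-compatibility check in `_proxy_task_handler` above can be tried in isolation. What follows is only a minimal sketch, not the project's code: `FramePushed` here is a stand-in dataclass whose field names follow the diff, and `LegacyObserver`, `ModernObserver`, `dispatch` and `demo` are hypothetical names introduced for illustration. It relies on `inspect.signature` leaving `self` out for bound methods, so more than one parameter indicates the old five-argument callback.

```python
import asyncio
import inspect
from dataclasses import dataclass


@dataclass
class FramePushed:
    # Stand-in for the real FramePushed; field names follow the diff.
    source: str
    destination: str
    frame: str
    direction: str
    timestamp: int


class LegacyObserver:
    # Hypothetical observer that still uses the five-argument callback.
    def __init__(self):
        self.seen = []

    async def on_push_frame(self, src, dst, frame, direction, timestamp):
        self.seen.append(("legacy", frame))


class ModernObserver:
    # Hypothetical observer that uses the single-argument callback.
    def __init__(self):
        self.seen = []

    async def on_push_frame(self, data: FramePushed):
        self.seen.append(("modern", data.frame))


async def dispatch(observer, data: FramePushed):
    # Bound methods report their parameters without `self`, so more than
    # one parameter means the observer expects the legacy signature.
    params = inspect.signature(observer.on_push_frame).parameters
    if len(params) > 1:
        await observer.on_push_frame(
            data.source, data.destination, data.frame, data.direction, data.timestamp
        )
    else:
        await observer.on_push_frame(data)


async def demo():
    data = FramePushed("a", "b", "TextFrame", "downstream", 0)
    legacy, modern = LegacyObserver(), ModernObserver()
    await dispatch(legacy, data)
    await dispatch(modern, data)
    return legacy.seen + modern.seen


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Inspecting the signature once per observer, rather than once per frame, would avoid repeated reflection on a hot path; the sketch keeps the per-call check only to mirror the diff.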
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Coroutine, Optional
|
||||
|
||||
@@ -21,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
SystemFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
|
||||
from pipecat.observers.base_observer import FramePushed
|
||||
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
|
||||
from pipecat.utils.asyncio import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
@@ -31,6 +33,51 @@ class FrameDirection(Enum):
|
||||
UPSTREAM = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameProcessorQueueItem:
|
||||
frame: Frame
|
||||
direction: FrameDirection
|
||||
callback: Optional[Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]]
|
||||
|
||||
|
||||
class FrameProcessorQueue:
|
||||
def __init__(self):
|
||||
self._queue = asyncio.Queue()
|
||||
self._urgent_queue = asyncio.Queue()
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def put(self, item: FrameProcessorQueueItem):
|
||||
if isinstance(item.frame, SystemFrame):
|
||||
await self._urgent_queue.put(item)
|
||||
else:
|
||||
await self._queue.put(item)
|
||||
self._event.set()
|
||||
|
||||
async def get(self) -> FrameProcessorQueueItem:
|
||||
# Wait for an item in any of the queues.
|
||||
await self._event.wait()
|
||||
|
||||
if self._urgent_queue.empty():
|
||||
item = await self._queue.get()
|
||||
self._queue.task_done()
|
||||
else:
|
||||
item = await self._urgent_queue.get()
|
||||
self._urgent_queue.task_done()
|
||||
|
||||
# Clear the event only if all queues are empty.
|
||||
if self._queue.empty() and self._urgent_queue.empty():
|
||||
self._event.clear()
|
||||
|
||||
return item
|
||||
|
||||
def clear(self):
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
# Clear the event only if all queues are empty.
|
||||
if self._queue.empty() and self._urgent_queue.empty():
|
||||
self._event.clear()
|
||||
|
||||
|
||||
class FrameProcessor(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
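The ordering behaviour of `FrameProcessorQueue` in this hunk (urgent items are always returned before regular ones, and a single event signals that either lane has work) can be sketched on its own. This is an illustrative reduction under stated assumptions: `Item`, `TwoLaneQueue` and `demo` are hypothetical names, and an `urgent` flag stands in for the `isinstance(item.frame, SystemFrame)` test used in the diff.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Item:
    payload: str
    urgent: bool = False  # stand-in for "frame is a SystemFrame"


class TwoLaneQueue:
    def __init__(self):
        self._normal = asyncio.Queue()
        self._urgent = asyncio.Queue()
        self._event = asyncio.Event()

    async def put(self, item: Item):
        lane = self._urgent if item.urgent else self._normal
        await lane.put(item)
        # Wake up any consumer waiting in get().
        self._event.set()

    async def get(self) -> Item:
        # Wait until at least one lane has an item.
        await self._event.wait()
        lane = self._normal if self._urgent.empty() else self._urgent
        item = lane.get_nowait()
        # Clear the event only when both lanes are drained.
        if self._normal.empty() and self._urgent.empty():
            self._event.clear()
        return item


async def demo():
    q = TwoLaneQueue()
    await q.put(Item("audio-1"))
    await q.put(Item("audio-2"))
    await q.put(Item("cancel", urgent=True))
    return [(await q.get()).payload for _ in range(3)]


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

With a single consumer, the urgent item is returned first even though it was queued last, and the regular items keep their relative order.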
@@ -68,18 +115,21 @@ class FrameProcessor(BaseObject):
|
||||
self._metrics = metrics or FrameProcessorMetrics()
|
||||
self._metrics.set_processor_name(self.name)
|
||||
|
||||
# Processors have an input queue. The input queue will be processed
|
||||
# immediately (default) or it will block if `pause_processing_frames()`
|
||||
# Processors receive frames on a streaming queue which are then
|
||||
# processed by a streaming task. This guarantees that all frames are
|
||||
# processed in the same task. By default, the streaming queue is
|
||||
# processed immediately but it may block if `pause_processing_frames()`
|
||||
# is called. To resume processing frames we need to call
|
||||
# `resume_processing_frames()` which will wake up the event.
|
||||
self.__should_block_frames = False
|
||||
self.__input_event = asyncio.Event()
|
||||
self.__input_frame_task: Optional[asyncio.Task] = None
|
||||
self.__streaming_event = asyncio.Event()
|
||||
self.__streaming_queue = FrameProcessorQueue()
|
||||
self.__streaming_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Every processor in Pipecat should only output frames from a single
|
||||
# task. This avoids problems like audio overlapping. System frames are the
|
||||
# exception to this rule. This creates that task.
|
||||
self.__push_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_task: Optional[asyncio.Task] = None
|
||||
self.__process_urgent_queue = asyncio.Queue()
|
||||
self.__process_urgent_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -169,7 +219,8 @@ class FrameProcessor(BaseObject):
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self.__cancel_input_task()
|
||||
await self.__cancel_push_task()
|
||||
await self.__cancel_process_task()
|
||||
await self.__cancel_process_urgent_task()
|
||||
|
||||
def link(self, processor: "FrameProcessor"):
|
||||
self._next = processor
|
||||
@@ -214,7 +265,7 @@ class FrameProcessor(BaseObject):
|
||||
await self.process_frame(frame, direction)
|
||||
else:
|
||||
# We queue everything else.
|
||||
await self.__input_queue.put((frame, direction, callback))
|
||||
await self.__streaming_queue.put(FrameProcessorQueueItem(frame, direction, callback))
|
||||
|
||||
async def pause_processing_frames(self):
|
||||
logger.trace(f"{self}: pausing frame processing")
|
||||
@@ -222,7 +273,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
async def resume_processing_frames(self):
|
||||
logger.trace(f"{self}: resuming frame processing")
|
||||
self.__input_event.set()
|
||||
self.__streaming_event.set()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, StartFrame):
|
||||
@@ -249,19 +300,48 @@ class FrameProcessor(BaseObject):
|
||||
if not self._check_ready(frame):
|
||||
return
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
else:
|
||||
await self.__push_queue.put((frame, direction))
|
||||
try:
|
||||
timestamp = self._clock.get_time() if self._clock else 0
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
|
||||
if self._observer:
|
||||
data = FramePushed(
|
||||
source=self,
|
||||
destination=self._next,
|
||||
frame=frame,
|
||||
direction=direction,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._observer.on_push_frame(data)
|
||||
await self._next.queue_frame(frame, direction)
|
||||
elif direction == FrameDirection.UPSTREAM and self._prev:
|
||||
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
|
||||
if self._observer:
|
||||
data = FramePushed(
|
||||
source=self,
|
||||
destination=self._prev,
|
||||
frame=frame,
|
||||
direction=direction,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._observer.on_push_frame(data)
|
||||
await self._prev.queue_frame(frame, direction)
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
raise
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
self.__create_process_task()
|
||||
self.__create_process_urgent_task()
|
||||
self.__create_input_task()
|
||||
self.__create_push_task()
|
||||
|
||||
async def __cancel(self, frame: CancelFrame):
|
||||
self._cancelling = True
|
||||
await self.__cancel_input_task()
|
||||
await self.__cancel_push_task()
|
||||
await self.__cancel_process_task()
|
||||
await self.__cancel_process_urgent_task()
|
||||
|
||||
#
|
||||
# Handle interruptions
|
||||
@@ -269,48 +349,32 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
async def _start_interruption(self):
|
||||
try:
|
||||
# Cancel the push frame task. This will stop pushing frames downstream.
|
||||
await self.__cancel_push_task()
|
||||
|
||||
# Cancel the input task. This will stop processing queued frames.
|
||||
# Cancel the streaming task.
|
||||
await self.__cancel_input_task()
|
||||
|
||||
# Cancel the task processing frames. We do not cancel the task that
|
||||
# is processing urgent frames.
|
||||
await self.__cancel_process_task()
|
||||
|
||||
# If there's an interruption we should not block frames anymore.
|
||||
self.__should_block_frames = False
|
||||
|
||||
# Clear the streaming queue, since we don't want to process its
|
||||
# frames anymore (except system and urgent frames).
|
||||
self.__streaming_queue.clear()
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
raise
|
||||
|
||||
# Create a new input queue and task.
|
||||
# Create new tasks.
|
||||
self.__create_process_task()
|
||||
self.__create_input_task()
|
||||
|
||||
# Create a new output queue and task.
|
||||
self.__create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
# Nothing to do right now.
|
||||
pass
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
try:
|
||||
timestamp = self._clock.get_time() if self._clock else 0
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
if self._observer:
|
||||
await self._observer.on_push_frame(
|
||||
self, self._next, frame, direction, timestamp
|
||||
)
|
||||
await self._next.queue_frame(frame, direction)
|
||||
elif direction == FrameDirection.UPSTREAM and self._prev:
|
||||
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
|
||||
if self._observer:
|
||||
await self._observer.on_push_frame(
|
||||
self, self._prev, frame, direction, timestamp
|
||||
)
|
||||
await self._prev.queue_frame(frame, direction)
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
raise
|
||||
|
||||
def _check_ready(self, frame: Frame):
|
||||
# If we are trying to push a frame but we still have no clock, it means
|
||||
# we didn't process a StartFrame.
|
||||
@@ -322,49 +386,60 @@ class FrameProcessor(BaseObject):
|
||||
return True
|
||||
|
||||
def __create_input_task(self):
|
||||
if not self.__input_frame_task:
|
||||
self.__should_block_frames = False
|
||||
self.__input_event.clear()
|
||||
self.__input_queue = asyncio.Queue()
|
||||
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
|
||||
if not self.__streaming_frame_task:
|
||||
self.__streaming_frame_task = self.create_task(self.__streaming_frame_task_handler())
|
||||
|
||||
async def __cancel_input_task(self):
|
||||
if self.__input_frame_task:
|
||||
await self.cancel_task(self.__input_frame_task)
|
||||
self.__input_frame_task = None
|
||||
if self.__streaming_frame_task:
|
||||
await self.cancel_task(self.__streaming_frame_task)
|
||||
self.__streaming_frame_task = None
|
||||
|
||||
async def __input_frame_task_handler(self):
|
||||
def __create_process_task(self):
|
||||
if not self.__process_task:
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_task = self.create_task(
|
||||
self.__process_task_handler(self.__process_queue)
|
||||
)
|
||||
|
||||
async def __cancel_process_task(self):
|
||||
if self.__process_task:
|
||||
await self.cancel_task(self.__process_task)
|
||||
self.__process_task = None
|
||||
|
||||
def __create_process_urgent_task(self):
|
||||
if not self.__process_urgent_task:
|
||||
self.__process_urgent_task = self.create_task(
|
||||
self.__process_task_handler(self.__process_urgent_queue)
|
||||
)
|
||||
|
||||
async def __cancel_process_urgent_task(self):
|
||||
if self.__process_urgent_task:
|
||||
await self.cancel_task(self.__process_urgent_task)
|
||||
self.__process_urgent_task = None
|
||||
|
||||
async def __streaming_frame_task_handler(self):
|
||||
while True:
|
||||
if self.__should_block_frames:
|
||||
logger.trace(f"{self}: frame processing paused")
|
||||
await self.__input_event.wait()
|
||||
self.__input_event.clear()
|
||||
await self.__streaming_event.wait()
|
||||
self.__streaming_event.clear()
|
||||
self.__should_block_frames = False
|
||||
logger.trace(f"{self}: frame processing resumed")
|
||||
|
||||
(frame, direction, callback) = await self.__input_queue.get()
|
||||
item = await self.__streaming_queue.get()
|
||||
|
||||
if isinstance(item.frame, SystemFrame):
|
||||
await self.__process_urgent_queue.put(item)
|
||||
else:
|
||||
await self.__process_queue.put(item)
|
||||
|
||||
async def __process_task_handler(self, queue: asyncio.Queue):
|
||||
while True:
|
||||
item = await queue.get()
|
||||
|
||||
# Process the frame.
|
||||
await self.process_frame(frame, direction)
|
||||
await self.process_frame(item.frame, item.direction)
|
||||
|
||||
# If this frame has an associated callback, call it now.
|
||||
if callback:
|
||||
await callback(self, frame, direction)
|
||||
|
||||
self.__input_queue.task_done()
|
||||
|
||||
def __create_push_task(self):
|
||||
if not self.__push_frame_task:
|
||||
self.__push_queue = asyncio.Queue()
|
||||
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
|
||||
|
||||
async def __cancel_push_task(self):
|
||||
if self.__push_frame_task:
|
||||
await self.cancel_task(self.__push_frame_task)
|
||||
self.__push_frame_task = None
|
||||
|
||||
async def __push_frame_task_handler(self):
|
||||
while True:
|
||||
(frame, direction) = await self.__push_queue.get()
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
self.__push_queue.task_done()
|
||||
if item.callback:
|
||||
await item.callback(self, item.frame, item.direction)
|
||||
|
||||
@@ -55,7 +55,7 @@ from pipecat.metrics.metrics import (
|
||||
TTFBMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -254,7 +254,7 @@ class RTVIBotReady(BaseModel):
|
||||
class RTVILLMFunctionCallMessageData(BaseModel):
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
args: Mapping[str, Any]
|
||||
|
||||
|
||||
class RTVILLMFunctionCallMessage(BaseModel):
|
||||
@@ -445,14 +445,7 @@ class RTVIObserver(BaseObserver):
|
||||
self._frames_seen = set()
|
||||
rtvi.set_errors_enabled(self._params.errors_enabled)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
"""Process a frame being pushed through the pipeline.
|
||||
|
||||
Args:
|
||||
@@ -462,6 +455,10 @@ class RTVIObserver(BaseObserver):
|
||||
direction: Direction of frame flow in pipeline
|
||||
timestamp: Time when frame was pushed
|
||||
"""
|
||||
src = data.source
|
||||
frame = data.frame
|
||||
direction = data.direction
|
||||
|
||||
# If we have already seen this frame, let's skip it.
|
||||
if frame.id in self._frames_seen:
|
||||
return
|
||||
@@ -703,7 +700,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
fn = RTVILLMFunctionCallMessageData(
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
arguments=params.arguments,
|
||||
args=params.arguments,
|
||||
)
|
||||
message = RTVILLMFunctionCallMessage(data=fn)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
|
||||
@@ -250,14 +250,24 @@ class AnthropicLLMService(LLMService):
|
||||
if hasattr(event.message.usage, "output_tokens")
|
||||
else 0
|
||||
)
|
||||
if hasattr(event.message.usage, "cache_creation_input_tokens"):
|
||||
cache_creation_input_tokens += (
|
||||
event.message.usage.cache_creation_input_tokens
|
||||
cache_creation_input_tokens += (
|
||||
event.message.usage.cache_creation_input_tokens
|
||||
if (
|
||||
hasattr(event.message.usage, "cache_creation_input_tokens")
|
||||
and event.message.usage.cache_creation_input_tokens is not None
|
||||
)
|
||||
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
|
||||
if hasattr(event.message.usage, "cache_read_input_tokens"):
|
||||
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
|
||||
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
|
||||
else 0
|
||||
)
|
||||
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
|
||||
cache_read_input_tokens += (
|
||||
event.message.usage.cache_read_input_tokens
|
||||
if (
|
||||
hasattr(event.message.usage, "cache_read_input_tokens")
|
||||
and event.message.usage.cache_read_input_tokens is not None
|
||||
)
|
||||
else 0
|
||||
)
|
||||
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
|
||||
total_input_tokens = (
|
||||
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
)
|
||||
|
||||
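The change in this hunk guards against usage fields that are present but set to `None`, which `hasattr` alone does not catch. A minimal sketch of the same idea, assuming a hypothetical helper `token_count` and a `SimpleNamespace` standing in for the usage object of the real response:

```python
from types import SimpleNamespace


def token_count(usage, field: str) -> int:
    # Treat both a missing attribute and an explicit None as zero tokens.
    value = getattr(usage, field, None)
    return value if value is not None else 0


# Stand-in usage object: one field is None, another is absent entirely.
usage = SimpleNamespace(input_tokens=12, cache_creation_input_tokens=None)

cache_creation = token_count(usage, "cache_creation_input_tokens")
cache_read = token_count(usage, "cache_read_input_tokens")
total_input = usage.input_tokens + cache_creation + cache_read
```

Using `getattr` with a default folds the `hasattr` check and the `None` check into one lookup, so adding `None` to a running total can no longer raise a `TypeError`.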
@@ -8,6 +8,8 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]")
|
||||
|
||||
785
src/pipecat/services/aws/llm.py
Normal file
@@ -0,0 +1,785 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import httpx
|
||||
from botocore.config import Config
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set the `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variables."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSBedrockContextAggregatorPair:
|
||||
_user: "AWSBedrockUserContextAggregator"
|
||||
_assistant: "AWSBedrockAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "AWSBedrockUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "AWSBedrockAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
*,
|
||||
system: Optional[str] = None,
|
||||
):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.system = system
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
|
||||
logger.debug(f"Upgrading to AWS Bedrock: {obj}")
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext):
|
||||
obj.__class__ = AWSBedrockLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
else:
|
||||
obj._restructure_from_bedrock_messages()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
tools=openai_context.tools,
|
||||
tool_choice=openai_context.tool_choice,
|
||||
)
|
||||
self.set_llm_adapter(openai_context.get_llm_adapter())
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
|
||||
self = cls(messages=messages)
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
|
||||
context = cls()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
return context
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
# convert a message in AWS Bedrock format into one or more messages in OpenAI format
|
||||
def to_standard_messages(self, obj):
|
||||
"""Convert AWS Bedrock message format to standard structured format.
|
||||
|
||||
Handles text content and function calls for both user and assistant messages.
|
||||
|
||||
Args:
|
||||
obj: Message in AWS Bedrock format:
|
||||
{
|
||||
"role": "user/assistant",
|
||||
"content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}]
|
||||
}
|
||||
|
||||
Returns:
|
||||
List of messages in standard format:
|
||||
[
|
||||
{
|
||||
"role": "user/assistant/tool",
|
||||
"content": [{"type": "text", "text": str}]
|
||||
}
|
||||
]
|
||||
"""
|
||||
role = obj.get("role")
|
||||
content = obj.get("content")
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif "toolUse" in item:
|
||||
tool_use = item["toolUse"]
|
||||
tool_items.append(
|
||||
{
|
||||
"type": "function",
|
||||
"id": tool_use["toolUseId"],
|
||||
"function": {
|
||||
"name": tool_use["name"],
|
||||
"arguments": json.dumps(tool_use["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
if tool_items:
|
||||
messages.append({"role": role, "tool_calls": tool_items})
|
||||
return messages
|
||||
elif role == "user":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif "toolResult" in item:
|
||||
tool_result = item["toolResult"]
|
||||
# Extract content from toolResult
|
||||
result_content = ""
|
||||
if isinstance(tool_result["content"], list):
|
||||
for content_item in tool_result["content"]:
|
||||
if "text" in content_item:
|
||||
result_content = content_item["text"]
|
||||
elif "json" in content_item:
|
||||
result_content = json.dumps(content_item["json"])
|
||||
else:
|
||||
result_content = tool_result["content"]
|
||||
|
||||
tool_items.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_result["toolUseId"],
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
messages.extend(tool_items)
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to AWS Bedrock format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard format:
|
||||
{
|
||||
"role": "user/assistant/tool",
|
||||
"content": str | [{"type": "text", ...}],
|
||||
"tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}]
|
||||
}
|
||||
|
||||
Returns:
|
||||
Message in AWS Bedrock format:
|
||||
{
|
||||
"role": "user/assistant",
|
||||
"content": [
|
||||
{"text": str} |
|
||||
{"toolUse": {"toolUseId": str, "name": str, "input": dict}} |
|
||||
{"toolResult": {"toolUseId": str, "content": [...], "status": str}}
|
||||
]
|
||||
}
|
||||
"""
|
||||
if message["role"] == "tool":
|
||||
# Try to parse the content as JSON if it looks like JSON
|
||||
try:
|
||||
if message["content"].strip().startswith("{") and message[
|
||||
"content"
|
||||
].strip().endswith("}"):
|
||||
content_json = json.loads(message["content"])
|
||||
tool_result_content = [{"json": content_json}]
|
||||
else:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
except Exception:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message["tool_call_id"],
|
||||
"content": tool_result_content,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"toolUse": {
|
||||
"toolUseId": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
|
||||
# Handle text content
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
return {"role": message["role"], "content": [{"text": "(empty)"}]}
|
||||
else:
|
||||
return {"role": message["role"], "content": [{"text": content}]}
|
||||
elif isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
if item.get("type", "") == "text":
|
||||
text_content = item["text"] if item["text"] != "" else "(empty)"
|
||||
new_content.append({"text": text_content})
|
||||
return {"role": message["role"], "content": new_content}
|
||||
|
||||
return message
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# Image should be the first content block in the message
|
||||
content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}]
|
||||
if text:
|
||||
content.append({"text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_message(self, message):
|
||||
try:
|
||||
if self.messages:
|
||||
# AWS Bedrock requires that roles alternate. If this message's
|
||||
# role is the same as the last message, we should add this
|
||||
# message's content to the last message.
|
||||
if self.messages[-1]["role"] == message["role"]:
|
||||
# if the last message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(self.messages[-1]["content"], str):
|
||||
self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}]
|
||||
# if this message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(message["content"], str):
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
# append the content of this message to the last message
|
||||
self.messages[-1]["content"].extend(message["content"])
|
||||
else:
|
||||
self.messages.append(message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def _restructure_from_bedrock_messages(self):
|
||||
"""Restructure messages in AWS Bedrock format by handling system
|
||||
messages, merging consecutive messages with the same role, and ensuring
|
||||
proper content formatting.
|
||||
|
||||
"""
|
||||
# Handle system message if present at the beginning
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
if len(self.messages) == 1:
|
||||
self.messages[0]["role"] = "user"
|
||||
else:
|
||||
system_content = self.messages.pop(0)["content"]
|
||||
if isinstance(system_content, str):
|
||||
system_content = [{"text": system_content}]
|
||||
|
||||
if self.system:
|
||||
if isinstance(self.system, str):
|
||||
self.system = [{"text": self.system}]
|
||||
self.system.extend(system_content)
|
||||
else:
|
||||
self.system = system_content
|
||||
|
||||
# Ensure content is properly formatted
|
||||
for msg in self.messages:
|
||||
if isinstance(msg["content"], str):
|
||||
msg["content"] = [{"text": msg["content"]}]
|
||||
elif not msg["content"]:
|
||||
msg["content"] = [{"text": "(empty)"}]
|
||||
elif isinstance(msg["content"], list):
|
||||
for idx, item in enumerate(msg["content"]):
|
||||
if isinstance(item, dict) and "text" in item and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
elif isinstance(item, str) and item == "":
|
||||
msg["content"][idx] = {"text": "(empty)"}
|
||||
|
||||
# Merge consecutive messages with the same role
|
||||
merged_messages = []
|
||||
for msg in self.messages:
|
||||
if merged_messages and merged_messages[-1]["role"] == msg["role"]:
|
||||
merged_messages[-1]["content"].extend(msg["content"])
|
||||
else:
|
||||
merged_messages.append(msg)
|
||||
|
||||
self.messages.clear()
|
||||
self.messages.extend(merged_messages)
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
|
||||
try:
|
||||
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our context.messages list. (For
|
||||
# compatibility with the OpenAI message format.)
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
self.system = self.messages[0]["content"]
|
||||
self.messages.pop(0)
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(self.messages) - 1:
|
||||
current_message = self.messages[i]
|
||||
next_message = self.messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
self.messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in self.messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
if self.system:
|
||||
messages.insert(0, {"role": "system", "content": self.system})
|
||||
return messages
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("image"):
|
||||
item["image"]["source"]["bytes"] = "..."
|
||||
msgs.append(msg)
|
||||
return json.dumps(msgs)
|
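The role-merging step in the context class above can be read in isolation. The following is a minimal sketch, not the code from this change: `merge_consecutive_roles` and `history` are hypothetical names, and it assumes every message is a dict with `role` and `content` keys where `content` is either a string or a list of `{"text": ...}` blocks. Unlike the in-place version in the diff, it returns a new list and leaves its input untouched.

```python
from typing import Any, Dict, List

Message = Dict[str, Any]


def merge_consecutive_roles(messages: List[Message]) -> List[Message]:
    """Return a new list in which adjacent same-role messages are merged."""
    merged: List[Message] = []
    for msg in messages:
        content = msg["content"]
        # Normalize plain strings into a list of text blocks; copy existing lists
        # so the caller's data is never mutated.
        blocks = [{"text": content}] if isinstance(content, str) else list(content)
        if merged and merged[-1]["role"] == msg["role"]:
            # Same role as the previous message: append the blocks to it.
            merged[-1]["content"].extend(blocks)
        else:
            merged.append({"role": msg["role"], "content": blocks})
    return merged


# Hypothetical sample conversation with two consecutive user turns.
history = [
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": [{"text": "Are you there?"}]},
    {"role": "assistant", "content": "Yes."},
]
result = merge_consecutive_roles(history)
```

Returning a fresh list makes the helper easy to test; the in-place variant in the diff is presumably preferred there so that other holders of `self.messages` see the merged result.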
||||
|
||||
|
||||
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
|
||||
pass
|
||||
|
||||
|
||||
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
# Format tool use according to AWS Bedrock API
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": frame.tool_call_id,
|
||||
"name": frame.function_name,
|
||||
"input": frame.arguments if frame.arguments else {},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": frame.tool_call_id,
|
||||
"content": [{"text": "IN_PROGRESS"}],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message["role"] == "user":
|
||||
for content in message["content"]:
|
||||
if (
|
||||
isinstance(content, dict)
|
||||
and content.get("toolResult")
|
||||
and content["toolResult"]["toolUseId"] == tool_call_id
|
||||
):
|
||||
content["toolResult"]["content"] = [{"text": result}]
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
|
||||
class AWSBedrockLLMService(LLMService):
|
||||
"""This class implements inference with AWS Bedrock models including Amazon
|
||||
Nova and Anthropic Claude.
|
||||
|
||||
Requires AWS credentials to be configured in the environment or through
|
||||
boto3 configuration.
|
||||
|
||||
"""
|
||||
|
||||
# Override the default adapter to use the AWS Bedrock one.
|
||||
adapter_class = AWSBedrockLLMAdapter
|
||||
|
||||
class InputParams(BaseModel):
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
|
||||
stop_sequences: Optional[List[str]] = Field(default_factory=lambda: [])
|
||||
latency: Optional[str] = Field(default_factory=lambda: "standard")
|
||||
additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aws_access_key: Optional[str] = None,
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: str = "us-east-1",
|
||||
model: str,
|
||||
params: InputParams = InputParams(),
|
||||
client_config: Optional[Config] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Initialize the AWS Bedrock client
|
||||
if not client_config:
|
||||
client_config = Config(
|
||||
connect_timeout=300, # 5 minutes
|
||||
read_timeout=300, # 5 minutes
|
||||
retries={"max_attempts": 3},
|
||||
)
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key,
|
||||
aws_secret_access_key=aws_secret_key,
|
||||
aws_session_token=aws_session_token,
|
||||
region_name=aws_region,
|
||||
)
|
||||
self._client = session.client(service_name="bedrock-runtime", config=client_config)
|
||||
|
||||
self.set_model_name(model)
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"latency": params.latency,
|
||||
"additional_model_request_fields": params.additional_model_request_fields
|
||||
if isinstance(params.additional_model_request_fields, dict)
|
||||
else {},
|
||||
}
|
||||
|
||||
logger.info(f"Using AWS Bedrock model: {model}")
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AWSBedrockContextAggregatorPair:
|
||||
"""Create an instance of AWSBedrockContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
AWSBedrockContextAggregatorPair: A pair of context aggregators, one
|
||||
for the user and one for the assistant, encapsulated in an
|
||||
AWSBedrockContextAggregatorPair.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AWSBedrockLLMContext.from_openai_context(context)
|
||||
|
||||
user = AWSBedrockUserContextAggregator(context, params=user_params)
|
||||
assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params)
|
||||
return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def _process_context(self, context: AWSBedrockLLMContext):
|
||||
# Usage tracking
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
completion_tokens_estimate = 0
|
||||
cache_read_input_tokens = 0
|
||||
cache_creation_input_tokens = 0
|
||||
use_completion_tokens_estimate = False
|
||||
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Set up inference config
|
||||
inference_config = {
|
||||
"maxTokens": self._settings["max_tokens"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"topP": self._settings["top_p"],
|
||||
}
|
||||
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": self.model_name,
|
||||
"messages": context.messages,
|
||||
"inferenceConfig": inference_config,
|
||||
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
|
||||
}
|
||||
|
||||
# Add system message
|
||||
request_params["system"] = context.system
|
||||
|
||||
# Add tools if present
|
||||
if context.tools:
|
||||
tool_config = {"tools": context.tools}
|
||||
|
||||
# Add tool_choice if specified
|
||||
if context.tool_choice:
|
||||
if context.tool_choice == "auto":
|
||||
tool_config["toolChoice"] = {"auto": {}}
|
||||
elif context.tool_choice == "none":
|
||||
# Skip adding toolChoice for "none"
|
||||
pass
|
||||
elif (
|
||||
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
|
||||
):
|
||||
tool_config["toolChoice"] = {
|
||||
"tool": {"name": context.tool_choice["function"]["name"]}
|
||||
}
|
||||
|
||||
request_params["toolConfig"] = tool_config
|
||||
|
||||
# Add performance config if latency is specified
|
||||
if self._settings["latency"] in ["standard", "optimized"]:
|
||||
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
|
||||
|
||||
logger.debug(f"Calling AWS Bedrock model with: {request_params}")
|
||||
|
||||
# Call AWS Bedrock with streaming
|
||||
response = self._client.converse_stream(**request_params)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
# Process the streaming response
|
||||
tool_use_block = None
|
||||
json_accumulator = ""
|
||||
|
||||
for event in response["stream"]:
|
||||
# Handle text content
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
await self.push_frame(LLMTextFrame(delta["text"]))
|
||||
completion_tokens_estimate += self._estimate_tokens(delta["text"])
|
||||
elif "toolUse" in delta and "input" in delta["toolUse"]:
|
||||
# Handle partial JSON for tool use
|
||||
json_accumulator += delta["toolUse"]["input"]
|
||||
completion_tokens_estimate += self._estimate_tokens(
|
||||
delta["toolUse"]["input"]
|
||||
)
|
||||
|
||||
# Handle tool use start
|
||||
elif "contentBlockStart" in event:
|
||||
content_block_start = event["contentBlockStart"]["start"]
|
||||
if "toolUse" in content_block_start:
|
||||
tool_use_block = {
|
||||
"id": content_block_start["toolUse"].get("toolUseId", ""),
|
||||
"name": content_block_start["toolUse"].get("name", ""),
|
||||
}
|
||||
json_accumulator = ""
|
||||
|
||||
# Handle message completion with tool use
|
||||
elif "messageStop" in event and "stopReason" in event["messageStop"]:
|
||||
if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
|
||||
try:
|
||||
arguments = json.loads(json_accumulator) if json_accumulator else {}
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_use_block["id"],
|
||||
function_name=tool_use_block["name"],
|
||||
arguments=arguments,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse tool arguments: {json_accumulator}")
|
||||
|
||||
# Handle usage metrics if available
|
||||
if "metadata" in event and "usage" in event["metadata"]:
|
||||
usage = event["metadata"]["usage"]
|
||||
prompt_tokens += usage.get("inputTokens", 0)
|
||||
completion_tokens += usage.get("outputTokens", 0)
|
||||
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
|
||||
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
|
||||
# token estimate. Then re-raise the exception so all the processors running in this task
|
||||
# also get cancelled.
|
||||
use_completion_tokens_estimate = True
|
||||
raise
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
comp_tokens = (
|
||||
completion_tokens
|
||||
if not use_completion_tokens_estimate
|
||||
else completion_tokens_estimate
|
||||
)
|
||||
await self._report_usage_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=comp_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AWSBedrockLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
# This is only useful in very simple pipelines because it creates
|
||||
# a new context. Generally we want a context manager to catch
|
||||
# UserImageRawFrames coming through the pipeline and add them
|
||||
# to the context.
|
||||
context = AWSBedrockLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
return int(len(re.split(r"[^\w]+", text)) * 1.3)
|
||||
|
||||
async def _report_usage_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
):
|
||||
if prompt_tokens or completion_tokens:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
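The tool-use handling in `_process_context` accumulates partial JSON across stream events and parses it only when the message stops. The sketch below replays that logic over a plain list of dicts so it can be run without AWS access. `collect_tool_calls` and `sample_events` are hypothetical names, and the event shapes are copied from the keys the code above reads (`contentBlockStart`, `contentBlockDelta`, `messageStop`); they are an assumption about the stream, not a complete description of it.

```python
import json
from typing import Any, Dict, Iterable, List, Optional


def collect_tool_calls(events: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Rebuild tool calls from streamed events that carry partial JSON input."""
    calls: List[Dict[str, Any]] = []
    current: Optional[Dict[str, str]] = None
    buffer = ""
    for event in events:
        if "contentBlockStart" in event:
            start = event["contentBlockStart"]["start"]
            if "toolUse" in start:
                # A new tool-use block begins: remember its id and name, reset the buffer.
                current = {
                    "id": start["toolUse"].get("toolUseId", ""),
                    "name": start["toolUse"].get("name", ""),
                }
                buffer = ""
        elif "contentBlockDelta" in event:
            delta = event["contentBlockDelta"]["delta"]
            if "toolUse" in delta and "input" in delta["toolUse"]:
                # Input arrives as string fragments that are only valid JSON once joined.
                buffer += delta["toolUse"]["input"]
        elif "messageStop" in event:
            if event["messageStop"].get("stopReason") == "tool_use" and current:
                arguments = json.loads(buffer) if buffer else {}
                calls.append({**current, "arguments": arguments})
    return calls


# Hypothetical event sequence for a single tool call split across two deltas.
sample_events = [
    {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "t1", "name": "get_weather"}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"city": '}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '"Paris"}'}}}},
    {"messageStop": {"stopReason": "tool_use"}},
]
tool_calls = collect_tool_calls(sample_events)
```

Parsing only at `messageStop` avoids calling `json.loads` on incomplete fragments, which is why the buffer is kept as a string until then.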
||||
329
src/pipecat/services/aws/stt.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AWSTranscribeSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
region: Optional[str] = "us-east-1",
|
||||
sample_rate: int = 16000,
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"language": language,
|
||||
"media_encoding": "linear16", # AWS expects raw PCM
|
||||
"number_of_channels": 1,
|
||||
"show_speaker_label": False,
|
||||
"enable_channel_identification": False,
|
||||
}
|
||||
|
||||
# Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz
|
||||
if sample_rate not in [8000, 16000]:
|
||||
logger.warning(
|
||||
f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
|
||||
)
|
||||
self._settings["sample_rate"] = 16000
|
||||
|
||||
self._credentials = {
|
||||
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region": region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
self._ws_client = None
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self._connecting = False
|
||||
self._receive_task = None
|
||||
|
||||
def get_service_encoding(self, encoding: str) -> str:
|
||||
"""Convert internal encoding format to AWS Transcribe format."""
|
||||
encoding_map = {
|
||||
"linear16": "pcm", # AWS expects "pcm" for 16-bit linear PCM
|
||||
}
|
||||
return encoding_map.get(encoding, encoding)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the connection when the service starts."""
|
||||
await super().start(frame)
|
||||
logger.info("Starting AWS Transcribe service...")
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._connect()
|
||||
if self._ws_client and self._ws_client.open:
|
||||
logger.info("Successfully established WebSocket connection")
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
|
||||
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data and send it to AWS Transcribe."""
|
||||
try:
|
||||
# Ensure WebSocket is connected
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
logger.debug("WebSocket not connected, attempting to reconnect...")
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reconnect: {e}")
|
||||
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
event_message = build_event_message(audio)
|
||||
|
||||
# Send the formatted event message
|
||||
try:
|
||||
await self._ws_client.send(event_message)
|
||||
# Start metrics after first chunk sent
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(f"Connection closed while sending: {e}")
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
await self._disconnect()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_stt: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to AWS Transcribe with connection state management."""
|
||||
if self._ws_client and self._ws_client.open and self._receive_task:
|
||||
logger.debug(f"{self} Already connected")
|
||||
return
|
||||
|
||||
async with self._connection_lock:
|
||||
if self._connecting:
|
||||
logger.debug(f"{self} Connection already in progress")
|
||||
return
|
||||
|
||||
try:
|
||||
self._connecting = True
|
||||
logger.debug(f"{self} Starting connection process...")
|
||||
|
||||
if self._ws_client:
|
||||
await self._disconnect()
|
||||
|
||||
language_code = self.language_to_service_language(
|
||||
Language(self._settings["language"])
|
||||
)
|
||||
if not language_code:
|
||||
raise ValueError(f"Unsupported language: {self._settings['language']}")
|
||||
|
||||
# Generate random websocket key
|
||||
websocket_key = "".join(
|
||||
random.choices(
|
||||
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
|
||||
)
|
||||
)
|
||||
|
||||
# Add required headers
|
||||
extra_headers = {
|
||||
"Origin": "https://localhost",
|
||||
"Sec-WebSocket-Key": websocket_key,
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
# Get presigned URL
|
||||
presigned_url = get_presigned_url(
|
||||
region=self._credentials["region"],
|
||||
credentials={
|
||||
"access_key": self._credentials["aws_access_key_id"],
|
||||
"secret_key": self._credentials["aws_secret_access_key"],
|
||||
"session_token": self._credentials["aws_session_token"],
|
||||
},
|
||||
language_code=language_code,
|
||||
media_encoding=self.get_service_encoding(
|
||||
self._settings["media_encoding"]
|
||||
), # Convert to AWS format
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
number_of_channels=self._settings["number_of_channels"],
|
||||
enable_partial_results_stabilization=True,
|
||||
partial_results_stability="high",
|
||||
show_speaker_label=self._settings["show_speaker_label"],
|
||||
enable_channel_identification=self._settings["enable_channel_identification"],
|
||||
)
|
||||
|
||||
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
|
||||
# Connect with the required headers and settings
|
||||
self._ws_client = await websockets.connect(
|
||||
presigned_url,
|
||||
extra_headers=extra_headers,
|
||||
subprotocols=["mqtt"],
|
||||
ping_interval=None,
|
||||
ping_timeout=None,
|
||||
compression=None,
|
||||
)
|
||||
|
||||
logger.debug(f"{self} WebSocket connected, starting receive task...")
|
||||
|
||||
# Start receive task
|
||||
self._receive_task = self.create_task(self._receive_loop())
|
||||
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._connecting = False
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from AWS Transcribe."""
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
try:
|
||||
if self._ws_client and self._ws_client.open:
|
||||
# Send end-stream message
|
||||
end_stream = {"message-type": "event", "event": "end"}
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
self._ws_client = None
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert internal language enum to AWS Transcribe language code."""
|
||||
language_map = {
|
||||
Language.EN: "en-US",
|
||||
Language.ES: "es-US",
|
||||
Language.FR: "fr-FR",
|
||||
Language.DE: "de-DE",
|
||||
Language.IT: "it-IT",
|
||||
Language.PT: "pt-BR",
|
||||
Language.JA: "ja-JP",
|
||||
Language.KO: "ko-KR",
|
||||
Language.ZH: "zh-CN",
|
||||
}
|
||||
return language_map.get(language)
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Background task to receive and process messages from AWS Transcribe."""
|
||||
while True:
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
logger.warning(f"{self} WebSocket closed in receive loop")
|
||||
break
|
||||
|
||||
try:
|
||||
response = await self._ws_client.recv()
|
||||
headers, payload = decode_event(response)
|
||||
|
||||
if headers.get(":message-type") == "event":
|
||||
# Process transcription results
|
||||
results = payload.get("Transcript", {}).get("Results", [])
|
||||
if results:
|
||||
result = results[0]
|
||||
alternatives = result.get("Alternatives", [])
|
||||
if alternatives:
|
||||
transcript = alternatives[0].get("Transcript", "")
|
||||
is_final = not result.get("IsPartial", True)
|
||||
|
||||
if transcript:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(
|
||||
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
|
||||
)
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(
|
||||
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
break
|
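The receive loop above decides between a final and an interim transcription from the decoded payload. A minimal sketch of just that decision follows, assuming the payload is a dict shaped the way the loop reads it (`Transcript` → `Results` → `Alternatives`, with an `IsPartial` flag). `extract_transcript` is a hypothetical helper, not part of this change.

```python
from typing import Any, Dict, Optional, Tuple


def extract_transcript(payload: Dict[str, Any]) -> Optional[Tuple[str, bool]]:
    """Return (text, is_final) for the first result, or None if there is no text."""
    results = payload.get("Transcript", {}).get("Results", [])
    if not results:
        return None
    result = results[0]
    alternatives = result.get("Alternatives", [])
    if not alternatives:
        return None
    text = alternatives[0].get("Transcript", "")
    if not text:
        return None
    # A missing IsPartial flag is treated as partial, matching the loop above.
    is_final = not result.get("IsPartial", True)
    return text, is_final


# Hypothetical payloads for a partial and a final result.
partial = {"Transcript": {"Results": [{"IsPartial": True, "Alternatives": [{"Transcript": "hello wor"}]}]}}
final = {"Transcript": {"Results": [{"IsPartial": False, "Alternatives": [{"Transcript": "hello world"}]}]}}
partial_result = extract_transcript(partial)
final_result = extract_transcript(final)
```

In the service, a final result would map to a `TranscriptionFrame` and a partial one to an `InterimTranscriptionFrame`.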
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -26,9 +27,7 @@ try:
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -108,7 +107,7 @@ def language_to_aws_language(language: Language) -> Optional[str]:
|
||||
return language_map.get(language)
|
||||
|
||||
|
||||
class PollyTTSService(TTSService):
|
||||
class AWSPollyTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
engine: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
@@ -151,6 +150,24 @@ class PollyTTSService(TTSService):
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
# Get credentials from environment variables if not provided
|
||||
self._credentials = {
|
||||
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region": region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
# Validate that we have the required credentials
|
||||
if (
|
||||
not self._credentials["aws_access_key_id"]
|
||||
or not self._credentials["aws_secret_access_key"]
|
||||
):
|
||||
raise ValueError(
|
||||
"AWS credentials not found. Please provide them either through constructor parameters "
|
||||
"or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -165,18 +182,17 @@ class PollyTTSService(TTSService):
|
||||
|
||||
prosody_attrs = []
|
||||
# Prosody tags are only supported for standard and neural engines
|
||||
if self._settings["engine"] != "generative":
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["engine"] == "standard":
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
else:
|
||||
logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
ssml += text
|
||||
|
||||
@@ -187,6 +203,8 @@ class PollyTTSService(TTSService):
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
logger.trace(f"{self} SSML: {ssml}")
|
||||
|
||||
return ssml
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -248,3 +266,17 @@ class PollyTTSService(TTSService):
|
||||
|
||||
finally:
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class PollyTTSService(AWSPollyTTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
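The `PollyTTSService` shim above keeps the old class name working while steering callers to `AWSPollyTTSService`. The same pattern is sketched below with placeholder classes; `NewService` and `OldService` are hypothetical stand-ins. As a variation on the diff, this sketch passes `stacklevel=2` so the warning points at the caller rather than at the shim's constructor, and leaves filter configuration to the caller.

```python
import warnings


class NewService:
    """Stand-in for the renamed class."""

    def __init__(self, **kwargs):
        self.options = kwargs


class OldService(NewService):
    """Deprecated alias kept for backward compatibility."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        warnings.warn(
            "'OldService' is deprecated, use 'NewService' instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the code that instantiated OldService
        )


# Record warnings so the deprecation can be inspected instead of printed.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    service = OldService(voice="example")
```

Because the alias subclasses the new class, existing `isinstance` checks against either name continue to pass.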
||||
261
src/pipecat/services/aws/utils.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import binascii
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import struct
|
||||
import urllib.parse
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def get_presigned_url(
|
||||
*,
|
||||
region: str,
|
||||
credentials: Dict[str, Optional[str]],
|
||||
language_code: str,
|
||||
media_encoding: str = "pcm",
|
||||
sample_rate: int = 16000,
|
||||
number_of_channels: int = 1,
|
||||
enable_partial_results_stabilization: bool = True,
|
||||
partial_results_stability: str = "high",
|
||||
vocabulary_name: Optional[str] = None,
|
||||
vocabulary_filter_name: Optional[str] = None,
|
||||
show_speaker_label: bool = False,
|
||||
enable_channel_identification: bool = False,
|
||||
) -> str:
|
||||
"""Create a presigned URL for AWS Transcribe streaming."""
|
||||
access_key = credentials.get("access_key")
|
||||
secret_key = credentials.get("secret_key")
|
||||
session_token = credentials.get("session_token")
|
||||
|
||||
if not access_key or not secret_key:
|
||||
raise ValueError("AWS credentials are required")
|
||||
|
||||
# Initialize the URL generator
|
||||
url_generator = AWSTranscribePresignedURL(
|
||||
access_key=access_key, secret_key=secret_key, session_token=session_token, region=region
|
||||
)
|
||||
|
||||
# Get the presigned URL
|
||||
return url_generator.get_request_url(
|
||||
sample_rate=sample_rate,
|
||||
language_code=language_code,
|
||||
media_encoding=media_encoding,
|
||||
vocabulary_name=vocabulary_name,
|
||||
vocabulary_filter_name=vocabulary_filter_name,
|
||||
show_speaker_label=show_speaker_label,
|
||||
enable_channel_identification=enable_channel_identification,
|
||||
number_of_channels=number_of_channels,
|
||||
enable_partial_results_stabilization=enable_partial_results_stabilization,
|
||||
partial_results_stability=partial_results_stability,
|
||||
)
|
||||
|
||||
|
||||
class AWSTranscribePresignedURL:
|
||||
def __init__(
|
||||
self, access_key: str, secret_key: str, session_token: Optional[str], region: str = "us-east-1"
|
||||
):
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
self.session_token = session_token
|
||||
self.method = "GET"
|
||||
self.service = "transcribe"
|
||||
self.region = region
|
||||
self.endpoint = ""
|
||||
self.host = ""
|
||||
self.amz_date = ""
|
||||
self.datestamp = ""
|
||||
self.canonical_uri = "/stream-transcription-websocket"
|
||||
self.canonical_headers = ""
|
||||
self.signed_headers = "host"
|
||||
self.algorithm = "AWS4-HMAC-SHA256"
|
||||
self.credential_scope = ""
|
||||
self.canonical_querystring = ""
|
||||
self.payload_hash = ""
|
||||
self.canonical_request = ""
|
||||
self.string_to_sign = ""
|
||||
self.signature = ""
|
||||
self.request_url = ""
|
||||
|
||||
def get_request_url(
|
||||
self,
|
||||
sample_rate: int,
|
||||
language_code: str = "",
|
||||
media_encoding: str = "pcm",
|
||||
vocabulary_name: str = "",
|
||||
vocabulary_filter_name: str = "",
|
||||
show_speaker_label: bool = False,
|
||||
enable_channel_identification: bool = False,
|
||||
number_of_channels: int = 1,
|
||||
enable_partial_results_stabilization: bool = False,
|
||||
partial_results_stability: str = "",
|
||||
) -> str:
|
||||
self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
|
||||
self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
|
||||
self.datestamp = now.strftime("%Y%m%d")
|
||||
self.canonical_headers = f"host:{self.host}\n"
|
||||
self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"
|
||||
|
||||
# Create canonical querystring
|
||||
self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm
|
||||
self.canonical_querystring += (
|
||||
"&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope
|
||||
)
|
||||
self.canonical_querystring += "&X-Amz-Date=" + self.amz_date
|
||||
self.canonical_querystring += "&X-Amz-Expires=300"
|
||||
if self.session_token:
|
||||
self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote(
|
||||
self.session_token, safe=""
|
||||
)
|
||||
self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers
|
||||
|
||||
if enable_channel_identification:
|
||||
self.canonical_querystring += "&enable-channel-identification=true"
|
||||
if enable_partial_results_stabilization:
|
||||
self.canonical_querystring += "&enable-partial-results-stabilization=true"
|
||||
if language_code:
|
||||
self.canonical_querystring += "&language-code=" + language_code
|
||||
if media_encoding:
|
||||
self.canonical_querystring += "&media-encoding=" + media_encoding
|
||||
if number_of_channels > 1:
|
||||
self.canonical_querystring += "&number-of-channels=" + str(number_of_channels)
|
||||
if partial_results_stability:
|
||||
self.canonical_querystring += "&partial-results-stability=" + partial_results_stability
|
||||
if sample_rate:
|
||||
self.canonical_querystring += "&sample-rate=" + str(sample_rate)
|
||||
if show_speaker_label:
|
||||
self.canonical_querystring += "&show-speaker-label=true"
|
||||
if vocabulary_filter_name:
|
||||
self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name
|
||||
if vocabulary_name:
|
||||
self.canonical_querystring += "&vocabulary-name=" + vocabulary_name
|
||||
|
||||
# Create payload hash
|
||||
self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()
|
||||
|
||||
# Create canonical request
|
||||
self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}"
|
||||
|
||||
# Create string to sign
|
||||
credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"
|
||||
string_to_sign = (
|
||||
f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n"
|
||||
+ hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
|
||||
)
|
||||
|
||||
# Calculate signature
|
||||
k_date = hmac.new(
|
||||
f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest()
|
||||
k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest()
|
||||
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
|
||||
self.signature = hmac.new(
|
||||
k_signing, string_to_sign.encode("utf-8"), hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Add signature to query string
|
||||
self.canonical_querystring += "&X-Amz-Signature=" + self.signature
|
||||
|
||||
# Create request URL
|
||||
self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
|
||||
return self.request_url
|
||||
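The signature step in `get_request_url` follows the SigV4 key-derivation chain (date, then region, then service, then the literal `aws4_request`). A minimal standalone sketch of that chain is shown here; `derive_signing_key` and `sign` are hypothetical helper names that do not appear in the diff, and the secret and string to sign are placeholders rather than real values.

```python
import hashlib
import hmac


def _hmac_sha256(key: bytes, msg: str) -> bytes:
    """Return the raw HMAC-SHA256 digest of msg under key."""
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


def derive_signing_key(secret_key: str, datestamp: str, region: str, service: str) -> bytes:
    """Derive the SigV4 signing key: date -> region -> service -> 'aws4_request'."""
    k_date = _hmac_sha256(f"AWS4{secret_key}".encode("utf-8"), datestamp)
    k_region = _hmac_sha256(k_date, region)
    k_service = _hmac_sha256(k_region, service)
    return _hmac_sha256(k_service, "aws4_request")


def sign(signing_key: bytes, string_to_sign: str) -> str:
    """Hex-encode the HMAC of the string to sign (the X-Amz-Signature value)."""
    return hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()


# Placeholder inputs for illustration only (not real credentials).
key = derive_signing_key("EXAMPLE-SECRET", "20250101", "us-east-1", "transcribe")
string_to_sign = (
    "AWS4-HMAC-SHA256\n"
    "20250101T000000Z\n"
    "20250101/us-east-1/transcribe/aws4_request\n"
    + "0" * 64  # stands in for the SHA-256 hex digest of the canonical request
)
signature = sign(key, string_to_sign)
```

Because the key depends on the datestamp, a signing key derived on one UTC day cannot be reused for a URL stamped with another day.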
|
||||
|
||||
def get_headers(header_name: str, header_value: str) -> bytearray:
|
||||
"""Build a header following AWS event stream format."""
|
||||
name = header_name.encode("utf-8")
|
||||
name_byte_length = bytes([len(name)])
|
||||
value_type = bytes([7]) # 7 represents a string
|
||||
value = header_value.encode("utf-8")
|
||||
value_byte_length = struct.pack(">H", len(value))
|
||||
|
||||
# Construct the header
|
||||
header_list = bytearray()
|
||||
header_list.extend(name_byte_length)
|
||||
header_list.extend(name)
|
||||
header_list.extend(value_type)
|
||||
header_list.extend(value_byte_length)
|
||||
header_list.extend(value)
|
||||
return header_list
|
||||
|
||||
|
||||
def build_event_message(payload: bytes) -> bytes:
|
||||
"""
|
||||
Build an event message for AWS Transcribe streaming.
|
||||
Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py
|
||||
"""
|
||||
# Build headers
|
||||
content_type_header = get_headers(":content-type", "application/octet-stream")
|
||||
event_type_header = get_headers(":event-type", "AudioEvent")
|
||||
message_type_header = get_headers(":message-type", "event")
|
||||
|
||||
headers = bytearray()
|
||||
headers.extend(content_type_header)
|
||||
headers.extend(event_type_header)
|
||||
headers.extend(message_type_header)
|
||||
|
||||
# Calculate total byte length and headers byte length
|
||||
# 16 accounts for 8 byte prelude, 2x 4 byte CRCs
|
||||
total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16)
|
||||
headers_byte_length = struct.pack(">I", len(headers))
|
||||
|
||||
# Build the prelude
|
||||
prelude = bytearray([0] * 8)
|
||||
prelude[:4] = total_byte_length
|
||||
prelude[4:] = headers_byte_length
|
||||
|
||||
# Calculate checksum for prelude
|
||||
prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
|
||||
|
||||
# Construct the message
|
||||
message_as_list = bytearray()
|
||||
message_as_list.extend(prelude)
|
||||
message_as_list.extend(prelude_crc)
|
||||
message_as_list.extend(headers)
|
||||
message_as_list.extend(payload)
|
||||
|
||||
# Calculate checksum for message
|
||||
message = bytes(message_as_list)
|
||||
message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF)
|
||||
|
||||
# Add message checksum
|
||||
message_as_list.extend(message_crc)
|
||||
|
||||
return bytes(message_as_list)
|
||||
|
||||
|
||||
def decode_event(message: bytes):
|
||||
# Extract the prelude, headers, payload and CRC
|
||||
prelude = message[:8]
|
||||
total_length, headers_length = struct.unpack(">II", prelude)
|
||||
prelude_crc = struct.unpack(">I", message[8:12])[0]
|
||||
headers = message[12 : 12 + headers_length]
|
||||
payload = message[12 + headers_length : -4]
|
||||
message_crc = struct.unpack(">I", message[-4:])[0]
|
||||
|
||||
# Check the CRCs
|
||||
assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed"
|
||||
assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed"
|
||||
|
||||
# Parse the headers
|
||||
headers_dict = {}
|
||||
while headers:
|
||||
name_len = headers[0]
|
||||
name = headers[1 : 1 + name_len].decode("utf-8")
|
||||
value_type = headers[1 + name_len]
|
||||
value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0]
|
||||
value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8")
|
||||
headers_dict[name] = value
|
||||
headers = headers[4 + name_len + value_len :]
|
||||
|
||||
return headers_dict, json.loads(payload)
|
||||
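The framing used by `build_event_message` and `decode_event` (8-byte prelude, prelude CRC, headers, payload, message CRC) can be exercised as a round trip. The following is a self-contained sketch, not the module's code: `encode_string_header`, `encode_message`, and `decode_message` are hypothetical names, only string-typed headers (type 7) are assumed, and the payload is a made-up example.

```python
import binascii
import json
import struct


def encode_string_header(name: str, value: str) -> bytes:
    """Encode one header: 1-byte name length, name, type 7, 2-byte value length, value."""
    name_b = name.encode("utf-8")
    value_b = value.encode("utf-8")
    return bytes([len(name_b)]) + name_b + bytes([7]) + struct.pack(">H", len(value_b)) + value_b


def encode_message(headers: dict, payload: bytes) -> bytes:
    """Frame headers and payload: prelude, prelude CRC, headers, payload, message CRC."""
    header_bytes = b"".join(encode_string_header(k, v) for k, v in headers.items())
    total_length = len(header_bytes) + len(payload) + 16  # 8-byte prelude + two 4-byte CRCs
    prelude = struct.pack(">II", total_length, len(header_bytes))
    prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
    body = prelude + prelude_crc + header_bytes + payload
    return body + struct.pack(">I", binascii.crc32(body) & 0xFFFFFFFF)


def decode_message(message: bytes):
    """Validate the length and both CRCs, then return (headers, payload bytes)."""
    total_length, headers_length = struct.unpack(">II", message[:8])
    if total_length != len(message):
        raise ValueError("Length mismatch")
    if struct.unpack(">I", message[8:12])[0] != binascii.crc32(message[:8]) & 0xFFFFFFFF:
        raise ValueError("Prelude CRC check failed")
    if struct.unpack(">I", message[-4:])[0] != binascii.crc32(message[:-4]) & 0xFFFFFFFF:
        raise ValueError("Message CRC check failed")
    raw = message[12 : 12 + headers_length]
    headers = {}
    while raw:
        name_len = raw[0]
        name = raw[1 : 1 + name_len].decode("utf-8")
        # raw[1 + name_len] is the value type; this sketch assumes type 7 (string).
        value_len = struct.unpack(">H", raw[2 + name_len : 4 + name_len])[0]
        headers[name] = raw[4 + name_len : 4 + name_len + value_len].decode("utf-8")
        raw = raw[4 + name_len + value_len :]
    return headers, message[12 + headers_length : -4]


payload = json.dumps({"Transcript": {"Results": []}}).encode("utf-8")
framed = encode_message({":message-type": "event", ":event-type": "TranscriptEvent"}, payload)
decoded_headers, decoded_payload = decode_message(framed)
```

The sketch raises `ValueError` instead of using `assert` for the CRC checks, since assertions are skipped when Python runs with `-O`.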
1
src/pipecat/services/aws_nova_sonic/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .aws import AWSNovaSonicLLMService
|
||||
997
src/pipecat/services/aws_nova_sonic/aws.py
Normal file
@@ -0,0 +1,997 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from importlib.resources import files
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.context import (
|
||||
AWSNovaSonicAssistantContextAggregator,
|
||||
AWSNovaSonicContextAggregatorPair,
|
||||
AWSNovaSonicLLMContext,
|
||||
AWSNovaSonicUserContextAggregator,
|
||||
Role,
|
||||
)
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from aws_sdk_bedrock_runtime.client import (
|
||||
BedrockRuntimeClient,
|
||||
InvokeModelWithBidirectionalStreamOperationInput,
|
||||
)
|
||||
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
|
||||
from aws_sdk_bedrock_runtime.models import (
|
||||
BidirectionalInputPayloadPart,
|
||||
InvokeModelWithBidirectionalStreamInput,
|
||||
InvokeModelWithBidirectionalStreamInputChunk,
|
||||
InvokeModelWithBidirectionalStreamOperationOutput,
|
||||
InvokeModelWithBidirectionalStreamOutput,
|
||||
)
|
||||
from smithy_aws_core.credentials_resolvers.static import StaticCredentialsResolver
|
||||
from smithy_aws_core.identity import AWSCredentialsIdentity
|
||||
from smithy_core.aio.eventstream import DuplexEventStream
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws-nova-sonic]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AWSNovaSonicUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
AUDIO = "AUDIO"
|
||||
TEXT = "TEXT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
class TextStage(Enum):
|
||||
FINAL = "FINAL" # what has been said
|
||||
SPECULATIVE = "SPECULATIVE" # what's planned to be said
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentContent:
|
||||
type: ContentType
|
||||
role: Role
|
||||
text_stage: Optional[TextStage]  # None if not text
|
||||
text_content: Optional[str]  # starts as None, then fills in if text
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"CurrentContent(\n"
|
||||
f" type={self.type.name},\n"
|
||||
f" role={self.role.name},\n"
|
||||
f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class Params(BaseModel):
|
||||
# Audio input
|
||||
input_sample_rate: Optional[int] = Field(default=16000)
|
||||
input_sample_size: Optional[int] = Field(default=16)
|
||||
input_channel_count: Optional[int] = Field(default=1)
|
||||
|
||||
# Audio output
|
||||
output_sample_rate: Optional[int] = Field(default=24000)
|
||||
output_sample_size: Optional[int] = Field(default=16)
|
||||
output_channel_count: Optional[int] = Field(default=1)
|
||||
|
||||
# Inference
|
||||
max_tokens: Optional[int] = Field(default=1024)
|
||||
top_p: Optional[float] = Field(default=0.9)
|
||||
temperature: Optional[float] = Field(default=0.7)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMService(LLMService):
|
||||
# Override the default adapter to use the AWSNovaSonicLLMAdapter one
|
||||
adapter_class = AWSNovaSonicLLMAdapter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret_access_key: str,
|
||||
access_key_id: str,
|
||||
region: str,
|
||||
model: str = "amazon.nova-sonic-v1:0",
|
||||
voice_id: str = "matthew", # matthew, tiffany, amy
|
||||
params: Params = Params(),
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[ToolsSchema] = None,
|
||||
send_transcription_frames: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._secret_access_key = secret_access_key
|
||||
self._access_key_id = access_key_id
|
||||
self._region = region
|
||||
self._model = model
|
||||
self._client: Optional[BedrockRuntimeClient] = None
|
||||
self._voice_id = voice_id
|
||||
self._params = params
|
||||
self._system_instruction = system_instruction
|
||||
self._tools = tools
|
||||
self._send_transcription_frames = send_transcription_frames
|
||||
self._context: Optional[AWSNovaSonicLLMContext] = None
|
||||
self._stream: Optional[
|
||||
DuplexEventStream[
|
||||
InvokeModelWithBidirectionalStreamInput,
|
||||
InvokeModelWithBidirectionalStreamOutput,
|
||||
InvokeModelWithBidirectionalStreamOperationOutput,
|
||||
]
|
||||
] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._prompt_name: Optional[str] = None
|
||||
self._input_audio_content_name: Optional[str] = None
|
||||
self._content_being_received: Optional[CurrentContent] = None
|
||||
self._assistant_is_responding = False
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._assistant_response_trigger_audio: Optional[bytes] = (
|
||||
None # Not cleared on _disconnect()
|
||||
)
|
||||
self._disconnecting = False
|
||||
self._connected_time: Optional[float] = None
|
||||
self._wants_connection = False
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._wants_connection = True
|
||||
await self._start_connecting()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._wants_connection = False
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._wants_connection = False
|
||||
await self._disconnect()
|
||||
|
||||
#
|
||||
# conversation resetting
|
||||
#
|
||||
|
||||
async def reset_conversation(self):
|
||||
logger.debug("Resetting conversation")
|
||||
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False)
|
||||
|
||||
# Carry over previous context through disconnect
|
||||
context = self._context
|
||||
await self._disconnect()
|
||||
self._context = context
|
||||
|
||||
await self._start_connecting()
|
||||
|
||||
#
|
||||
# frame processing
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
await self._handle_context(frame.context)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._handle_input_audio_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=True)
|
||||
elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_context(self, context: OpenAILLMContext):
|
||||
if not self._context:
|
||||
# We got our initial context - try to finish connecting
|
||||
self._context = AWSNovaSonicLLMContext.upgrade_to_nova_sonic(
|
||||
context, self._system_instruction
|
||||
)
|
||||
await self._finish_connecting_if_context_available()
|
||||
|
||||
async def _handle_input_audio_frame(self, frame: InputAudioRawFrame):
|
||||
# Wait until we're done sending the assistant response trigger audio before sending audio
|
||||
# from the user's mic
|
||||
if self._triggering_assistant_response:
|
||||
return
|
||||
|
||||
await self._send_user_audio_event(frame.audio)
|
||||
|
||||
async def _handle_bot_stopped_speaking(self, delay_to_catch_trailing_assistant_text: bool):
|
||||
# Protect against back-to-back BotStoppedSpeaking calls, which I've observed
|
||||
if self._handling_bot_stopped_speaking:
|
||||
return
|
||||
self._handling_bot_stopped_speaking = True
|
||||
|
||||
async def finalize_assistant_response():
|
||||
if self._assistant_is_responding:
|
||||
# Consider the assistant finished with their response (possibly after a short delay,
|
||||
# to allow for any trailing FINAL assistant text block to come in that needs to make
|
||||
# it into context).
|
||||
#
|
||||
# TODO: ideally we could base this solely on the LLM output events, but I couldn't
|
||||
# figure out a reliable way to determine when we've gotten our last FINAL text block
|
||||
# after the LLM is done talking.
|
||||
#
|
||||
# First I looked at stopReason, but it doesn't seem like the last FINAL text block
|
||||
# is reliably marked END_TURN (sometimes the *first* one is, but not the last...
|
||||
# bug?)
|
||||
#
|
||||
# Then I considered schemes where we tally or match up SPECULATIVE text blocks with
|
||||
# FINAL text blocks to know how many or which FINAL blocks to expect, but user
|
||||
# interruptions throw a wrench in these schemes: depending on the exact timing of
|
||||
# the interruption, we should or shouldn't expect some FINAL blocks.
|
||||
if delay_to_catch_trailing_assistant_text:
|
||||
# This delay length is a balancing act between "catching" trailing assistant
|
||||
# text that is quite delayed but not waiting so long that user text comes in
|
||||
# first and results in a bit of context message order scrambling.
|
||||
await asyncio.sleep(1.25)
|
||||
self._assistant_is_responding = False
|
||||
await self._report_assistant_response_ended()
|
||||
|
||||
self._handling_bot_stopped_speaking = False
|
||||
|
||||
# Finalize the assistant response, either now or after a delay
|
||||
if delay_to_catch_trailing_assistant_text:
|
||||
self.create_task(finalize_assistant_response())
|
||||
else:
|
||||
await finalize_assistant_response()
|
||||
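The re-entrancy guard plus optional delayed finalization in `_handle_bot_stopped_speaking` can be tried in isolation. This is a simplified sketch under stated assumptions: `ResponseFinalizer` is a hypothetical stand-in for the service, it uses plain `asyncio.create_task` rather than the processor's `create_task`, and it resets the guard in a `finally` block, which the original does not do.

```python
import asyncio


class ResponseFinalizer:
    """Hypothetical stand-in for the service's bot-stopped-speaking handling."""

    def __init__(self, delay: float = 1.25):
        self._delay = delay
        self._handling = False
        self._responding = True
        self.finalized_count = 0

    async def handle_bot_stopped_speaking(self, delayed: bool):
        # Ignore re-entrant calls while a finalization is already in flight.
        if self._handling:
            return None
        self._handling = True

        async def finalize():
            try:
                if self._responding:
                    if delayed:
                        # Give trailing assistant text a chance to arrive first.
                        await asyncio.sleep(self._delay)
                    self._responding = False
                    self.finalized_count += 1
            finally:
                # Release the guard even if finalization raises.
                self._handling = False

        if delayed:
            return asyncio.create_task(finalize())
        await finalize()
        return None


async def demo() -> int:
    finalizer = ResponseFinalizer(delay=0.01)
    first = await finalizer.handle_bot_stopped_speaking(delayed=True)
    second = await finalizer.handle_bot_stopped_speaking(delayed=True)  # ignored by the guard
    assert second is None
    await first
    return finalizer.finalized_count


count = asyncio.run(demo())
```

Two back-to-back calls result in a single finalization, which is the behavior the guard flag is there to ensure.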
|
||||
async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame):
|
||||
result = frame.result_frame
|
||||
await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result)
|
||||
|
||||
#
|
||||
# LLM communication: lifecycle
|
||||
#
|
||||
|
||||
async def _start_connecting(self):
|
||||
try:
|
||||
logger.info("Connecting...")
|
||||
|
||||
if self._client:
|
||||
# Here we assume that if we have a client we are connected or connecting
|
||||
return
|
||||
|
||||
# Set IDs for the connection
|
||||
self._prompt_name = str(uuid.uuid4())
|
||||
self._input_audio_content_name = str(uuid.uuid4())
|
||||
|
||||
# Create the client
|
||||
self._client = self._create_client()
|
||||
|
||||
# Start the bidirectional stream
|
||||
self._stream = await self._client.invoke_model_with_bidirectional_stream(
|
||||
InvokeModelWithBidirectionalStreamOperationInput(model_id=self._model)
|
||||
)
|
||||
|
||||
# Send session start event
|
||||
await self._send_session_start_event()
|
||||
|
||||
# Finish connecting
|
||||
self._ready_to_send_context = True
|
||||
await self._finish_connecting_if_context_available()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _finish_connecting_if_context_available(self):
|
||||
# We can only finish connecting once we've gotten our initial context and we're ready to
|
||||
# send it
|
||||
if not (self._context and self._ready_to_send_context):
|
||||
return
|
||||
|
||||
logger.info("Finishing connecting (setting up session)...")
|
||||
|
||||
# Read context
|
||||
history = self._context.get_messages_for_initializing_history()
|
||||
|
||||
# Send prompt start event, specifying tools.
|
||||
# Tools from context take priority over self._tools.
|
||||
tools = (
|
||||
self._context.tools
|
||||
if self._context.tools
|
||||
else self.get_llm_adapter().from_standard_tools(self._tools)
|
||||
)
|
||||
logger.debug(f"Using tools: {tools}")
|
||||
await self._send_prompt_start_event(tools)
|
||||
|
||||
# Send system instruction.
|
||||
# Instruction from context takes priority over self._system_instruction.
|
||||
# (NOTE: this prioritizing occurred automatically behind the scenes: the context was
|
||||
# initialized with self._system_instruction and then updated itself from its messages when
|
||||
# get_messages_for_initializing_history() was called).
|
||||
logger.debug(f"Using system instruction: {history.system_instruction}")
|
||||
if history.system_instruction:
|
||||
await self._send_text_event(text=history.system_instruction, role=Role.SYSTEM)
|
||||
|
||||
# Send conversation history
|
||||
for message in history.messages:
|
||||
await self._send_text_event(text=message.text, role=message.role)
|
||||
|
||||
# Start audio input
|
||||
await self._send_audio_input_start_event()
|
||||
|
||||
# Start receiving events
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
# Record finished connecting time (must be done before sending assistant response trigger)
|
||||
self._connected_time = time.time()
|
||||
|
||||
logger.info("Finished connecting")
|
||||
|
||||
# If we need to, send assistant response trigger (depends on self._connected_time)
|
||||
if self._triggering_assistant_response:
|
||||
await self._send_assistant_response_trigger()
|
||||
self._triggering_assistant_response = False
|
||||
|
||||
async def _disconnect(self):
|
||||
try:
|
||||
logger.info("Disconnecting...")
|
||||
|
||||
# NOTE: see explanation of HACK, below
|
||||
self._disconnecting = True
|
||||
|
||||
# Clean up client
|
||||
if self._client:
|
||||
await self._send_session_end_events()
|
||||
self._client = None
|
||||
|
||||
# Clean up stream
|
||||
if self._stream:
|
||||
await self._stream.input_stream.close()
|
||||
self._stream = None
|
||||
|
||||
# NOTE: see explanation of HACK, below
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Clean up receive task
|
||||
# HACK: we should ideally be able to cancel the receive task before stopping the input
|
||||
# stream, above (meaning we wouldn't need self._disconnecting). But for some reason if
|
||||
# we don't close the input stream and wait a second first, we're getting an error a lot
|
||||
# like this one: https://github.com/awslabs/amazon-transcribe-streaming-sdk/issues/61.
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
|
||||
# Reset remaining connection-specific state
|
||||
self._prompt_name = None
|
||||
self._input_audio_content_name = None
|
||||
self._content_being_received = None
|
||||
self._assistant_is_responding = False
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._disconnecting = False
|
||||
self._connected_time = None
|
||||
|
||||
logger.info("Finished disconnecting")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
def _create_client(self) -> BedrockRuntimeClient:
|
||||
config = Config(
|
||||
endpoint_uri=f"https://bedrock-runtime.{self._region}.amazonaws.com",
|
||||
region=self._region,
|
||||
aws_credentials_identity_resolver=StaticCredentialsResolver(
|
||||
credentials=AWSCredentialsIdentity(
|
||||
access_key_id=self._access_key_id, secret_access_key=self._secret_access_key
|
||||
)
|
||||
),
|
||||
http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
|
||||
http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
|
||||
)
|
||||
return BedrockRuntimeClient(config=config)
|
||||
|
||||
#
|
||||
# LLM communication: input events (pipecat -> LLM)
|
||||
#
|
||||
|
||||
async def _send_session_start_event(self):
|
||||
session_start = f"""
|
||||
{{
|
||||
"event": {{
|
||||
"sessionStart": {{
|
||||
"inferenceConfiguration": {{
|
||||
"maxTokens": {self._params.max_tokens},
|
||||
"topP": {self._params.top_p},
|
||||
"temperature": {self._params.temperature}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
await self._send_client_event(session_start)
|
||||
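The `sessionStart` event above is assembled with an f-string template. An alternative, sketched here as an assumption rather than as the service's approach, is to build the same structure as a dict and serialize it with `json.dumps`; `build_session_start_event` is a hypothetical helper name.

```python
import json


def build_session_start_event(max_tokens: int, top_p: float, temperature: float) -> str:
    """Serialize a sessionStart event with json.dumps instead of an f-string template."""
    return json.dumps(
        {
            "event": {
                "sessionStart": {
                    "inferenceConfiguration": {
                        "maxTokens": max_tokens,
                        "topP": top_p,
                        "temperature": temperature,
                    }
                }
            }
        }
    )


# Values mirror the Params defaults shown in the diff.
event_json = build_session_start_event(max_tokens=1024, top_p=0.9, temperature=0.7)
```

One difference worth noting: the `Params` fields are `Optional`, and a `None` interpolated into an f-string renders as `None`, which is not valid JSON, whereas `json.dumps` emits `null`.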
|
||||
async def _send_prompt_start_event(self, tools: List[Any]):
|
||||
if not self._prompt_name:
|
||||
return
|
||||
|
||||
tools_config = (
|
||||
f""",
|
||||
"toolUseOutputConfiguration": {{
|
||||
"mediaType": "application/json"
|
||||
}},
|
||||
"toolConfiguration": {{
|
||||
"tools": {json.dumps(tools)}
|
||||
}}
|
||||
"""
|
||||
if tools
|
||||
else ""
|
||||
)
|
||||
|
||||
prompt_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"promptStart": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"textOutputConfiguration": {{
|
||||
"mediaType": "text/plain"
|
||||
}},
|
||||
"audioOutputConfiguration": {{
|
||||
"mediaType": "audio/lpcm",
|
||||
"sampleRateHertz": {self._params.output_sample_rate},
|
||||
"sampleSizeBits": {self._params.output_sample_size},
|
||||
"channelCount": {self._params.output_channel_count},
|
||||
"voiceId": "{self._voice_id}",
|
||||
"encoding": "base64",
|
||||
"audioType": "SPEECH"
|
||||
}}{tools_config}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(prompt_start)
|
||||
|
||||
async def _send_audio_input_start_event(self):
|
||||
if not self._prompt_name:
|
||||
return
|
||||
|
||||
audio_content_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"contentStart": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{self._input_audio_content_name}",
|
||||
"type": "AUDIO",
|
||||
"interactive": true,
|
||||
"role": "USER",
|
||||
"audioInputConfiguration": {{
|
||||
"mediaType": "audio/lpcm",
|
||||
"sampleRateHertz": {self._params.input_sample_rate},
|
||||
"sampleSizeBits": {self._params.input_sample_size},
|
||||
"channelCount": {self._params.input_channel_count},
|
||||
"audioType": "SPEECH",
|
||||
"encoding": "base64"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(audio_content_start)
|
||||
|
||||
async def _send_text_event(self, text: str, role: Role):
|
||||
if not self._stream or not self._prompt_name or not text:
|
||||
return
|
||||
|
||||
content_name = str(uuid.uuid4())
|
||||
|
||||
text_content_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"contentStart": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}",
|
||||
"type": "TEXT",
|
||||
"interactive": true,
|
||||
"role": "{role.value}",
|
||||
"textInputConfiguration": {{
|
||||
"mediaType": "text/plain"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(text_content_start)
|
||||
|
||||
escaped_text = json.dumps(text) # includes quotes
|
||||
text_input = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"textInput": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}",
|
||||
"content": {escaped_text}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(text_input)
|
||||
|
||||
text_content_end = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"contentEnd": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(text_content_end)
|
||||
|
||||
async def _send_user_audio_event(self, audio: bytes):
|
||||
if not self._stream:
|
||||
return
|
||||
|
||||
blob = base64.b64encode(audio)
|
||||
audio_event = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"audioInput": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{self._input_audio_content_name}",
|
||||
"content": "{blob.decode("utf-8")}"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(audio_event)
|
||||
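The `audioInput` event carries raw audio as a base64 string inside JSON. A small sketch of that encoding and its inverse follows; `build_audio_input_event` is a hypothetical helper, and the prompt and content names are placeholder strings rather than the UUIDs the service generates.

```python
import base64
import json


def build_audio_input_event(prompt_name: str, content_name: str, audio: bytes) -> str:
    """Wrap raw PCM bytes in an audioInput event, base64-encoding the content."""
    return json.dumps(
        {
            "event": {
                "audioInput": {
                    "promptName": prompt_name,
                    "contentName": content_name,
                    "content": base64.b64encode(audio).decode("utf-8"),
                }
            }
        }
    )


# 10 ms of 16 kHz, 16-bit mono silence: 160 samples * 2 bytes each.
silence = b"\x00" * 320
event_json = build_audio_input_event("prompt-1", "audio-1", silence)

# Decoding the content field recovers the original bytes.
recovered = base64.b64decode(json.loads(event_json)["event"]["audioInput"]["content"])
```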
|
||||
async def _send_session_end_events(self):
|
||||
if not self._stream or not self._prompt_name:
|
||||
return
|
||||
|
||||
prompt_end = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"promptEnd": {{
|
||||
"promptName": "{self._prompt_name}"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(prompt_end)
|
||||
|
||||
session_end = """
|
||||
{
|
||||
"event": {
|
||||
"sessionEnd": {}
|
||||
}
|
||||
}
|
||||
"""
|
||||
await self._send_client_event(session_end)
|
||||
|
||||
async def _send_tool_result(self, tool_call_id, result):
|
||||
if not self._stream or not self._prompt_name:
|
||||
return
|
||||
|
||||
content_name = str(uuid.uuid4())
|
||||
|
||||
result_content_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"contentStart": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}",
|
||||
"interactive": false,
|
||||
"type": "TOOL",
|
||||
"role": "TOOL",
|
||||
"toolResultInputConfiguration": {{
|
||||
"toolUseId": "{tool_call_id}",
|
||||
"type": "TEXT",
|
||||
"textInputConfiguration": {{
|
||||
"mediaType": "text/plain"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(result_content_start)
|
||||
|
||||
result_content = json.dumps(
|
||||
{
|
||||
"event": {
|
||||
"toolResult": {
|
||||
"promptName": self._prompt_name,
|
||||
"contentName": content_name,
|
||||
"content": json.dumps(result) if isinstance(result, dict) else result,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
await self._send_client_event(result_content)
|
||||
|
||||
result_content_end = f"""
|
||||
{{
|
||||
"event": {{
|
||||
"contentEnd": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
await self._send_client_event(result_content_end)
|
||||
|
||||
async def _send_client_event(self, event_json: str):
|
||||
if not self._stream: # should never happen
|
||||
return
|
||||
|
||||
event = InvokeModelWithBidirectionalStreamInputChunk(
|
||||
value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
|
||||
)
|
||||
await self._stream.input_stream.send(event)
|
||||
|
||||
#
|
||||
# LLM communication: output events (LLM -> pipecat)
|
||||
#
|
||||
|
||||
# Receive events for the session.
|
||||
# A few different kinds of content can be delivered:
|
||||
# - Transcription of user audio
|
||||
# - Tool use
|
||||
# - Text preview of planned response speech before audio delivered
|
||||
# - User interruption notification
|
||||
# - Text of response speech whose audio was actually delivered
|
||||
# - Audio of response speech
|
||||
# Each piece of content is wrapped by "contentStart" and "contentEnd" events. The content is
|
||||
# delivered sequentially: one piece of content will end before another starts.
|
||||
# The overall completion is wrapped by "completionStart" and "completionEnd" events.
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
while self._stream and not self._disconnecting:
|
||||
output = await self._stream.await_output()
|
||||
result = await output[1].receive()
|
||||
|
||||
if result.value and result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
json_data = json.loads(response_data)
|
||||
|
||||
if "event" in json_data:
|
||||
event_json = json_data["event"]
|
||||
if "completionStart" in event_json:
|
||||
# Handle the LLM completion starting
|
||||
await self._handle_completion_start_event(event_json)
|
||||
elif "contentStart" in event_json:
|
||||
# Handle a piece of content starting
|
||||
await self._handle_content_start_event(event_json)
|
||||
elif "textOutput" in event_json:
|
||||
# Handle text output content
|
||||
await self._handle_text_output_event(event_json)
|
||||
elif "audioOutput" in event_json:
|
||||
# Handle audio output content
|
||||
await self._handle_audio_output_event(event_json)
|
||||
elif "toolUse" in event_json:
|
||||
# Handle tool use
|
||||
await self._handle_tool_use_event(event_json)
|
||||
elif "contentEnd" in event_json:
|
||||
# Handle a piece of content ending
|
||||
await self._handle_content_end_event(event_json)
|
||||
elif "completionEnd" in event_json:
|
||||
# Handle the LLM completion ending
|
||||
await self._handle_completion_end_event(event_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
|
||||
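The receive loop above routes each decoded message by the first event key it finds. As a rough sketch of that routing in isolation, assuming a plain dict of async callables stands in for the `_handle_*` methods, one could write the following. `EVENT_KEYS`, `dispatch_event` and `handlers` are hypothetical names used only here.

```python
import asyncio

# Event keys in the same order the handler above checks them.
EVENT_KEYS = (
    "completionStart",
    "contentStart",
    "textOutput",
    "audioOutput",
    "toolUse",
    "contentEnd",
    "completionEnd",
)


async def dispatch_event(json_data, handlers):
    """Route one decoded message to the handler for its event kind.

    Returns the matched event key, or None if the message carries no
    recognized event. `handlers` maps an event key to an async callable.
    """
    event_json = json_data.get("event")
    if event_json is None:
        return None
    for key in EVENT_KEYS:
        if key in event_json:
            handler = handlers.get(key)
            if handler is not None:
                await handler(event_json)
            return key
    return None
```

A table like this keeps the check order explicit while making it easy to see which event kinds are handled.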
async def _handle_completion_start_event(self, event_json):
|
||||
pass
|
||||
|
||||
    async def _handle_content_start_event(self, event_json):
        content_start = event_json["contentStart"]
        content_type = content_start["type"]  # named to avoid shadowing the built-in `type`
        role = content_start["role"]
        generation_stage = None
        if "additionalModelFields" in content_start:
            additional_model_fields = json.loads(content_start["additionalModelFields"])
            generation_stage = additional_model_fields.get("generationStage")

        # Bookkeeping: track current content being received
        content = CurrentContent(
            type=ContentType(content_type),
            role=Role(role),
            text_stage=TextStage(generation_stage) if generation_stage else None,
            text_content=None,
        )
        self._content_being_received = content

        if content.role == Role.ASSISTANT:
            if content.type == ContentType.AUDIO:
                # Note that an assistant response can comprise multiple audio blocks
                if not self._assistant_is_responding:
                    # The assistant has started responding.
                    self._assistant_is_responding = True
                    await self._report_assistant_response_started()
|
||||
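One detail of the `contentStart` handling above is easy to miss: `additionalModelFields` is itself a JSON-encoded string, so it needs a second decode before `generationStage` can be read. A small sketch of just that parsing step, with `parse_content_start` as a hypothetical standalone function, could be:

```python
import json


def parse_content_start(event_json):
    """Extract (type, role, generation_stage) from a contentStart event."""
    content_start = event_json["contentStart"]
    generation_stage = None
    # additionalModelFields arrives as a JSON string nested inside the event JSON.
    if "additionalModelFields" in content_start:
        fields = json.loads(content_start["additionalModelFields"])
        generation_stage = fields.get("generationStage")
    return content_start["type"], content_start["role"], generation_stage
```

Audio blocks typically carry no generation stage, in which case the third element is `None`.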
async def _handle_text_output_event(self, event_json):
|
||||
if not self._content_being_received: # should never happen
|
||||
return
|
||||
content = self._content_being_received
|
||||
|
||||
text_content = event_json["textOutput"]["content"]
|
||||
|
||||
# Bookkeeping: augment the current content being received with text
|
||||
# Assumption: only one text content per content block
|
||||
content.text_content = text_content
|
||||
|
||||
async def _handle_audio_output_event(self, event_json):
|
||||
if not self._content_being_received: # should never happen
|
||||
return
|
||||
|
||||
# Get audio
|
||||
audio_content = event_json["audioOutput"]["content"]
|
||||
|
||||
# Push audio frame
|
||||
audio = base64.b64decode(audio_content)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio,
|
||||
sample_rate=self._params.output_sample_rate,
|
||||
num_channels=self._params.output_channel_count,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tool_use_event(self, event_json):
|
||||
if not self._content_being_received or not self._context: # should never happen
|
||||
return
|
||||
|
||||
# Get tool use details
|
||||
tool_use = event_json["toolUse"]
|
||||
function_name = tool_use["toolName"]
|
||||
tool_call_id = tool_use["toolUseId"]
|
||||
arguments = json.loads(tool_use["content"])
|
||||
|
||||
# Call tool function
|
||||
if self.has_function(function_name):
|
||||
if function_name in self._functions or None in self._functions:
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
else:
|
||||
raise AWSNovaSonicUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def _handle_content_end_event(self, event_json):
|
||||
if not self._content_being_received: # should never happen
|
||||
return
|
||||
content = self._content_being_received
|
||||
|
||||
content_end = event_json["contentEnd"]
|
||||
stop_reason = content_end["stopReason"]
|
||||
|
||||
# Bookkeeping: clear current content being received
|
||||
self._content_being_received = None
|
||||
|
||||
if content.role == Role.ASSISTANT:
|
||||
if content.type == ContentType.TEXT:
|
||||
# Ignore non-final text, and the "interrupted" message (which isn't meaningful text)
|
||||
if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED":
|
||||
if self._assistant_is_responding:
|
||||
# Text added to the ongoing assistant response
|
||||
await self._report_assistant_response_text_added(content.text_content)
|
||||
elif content.role == Role.USER:
|
||||
if content.type == ContentType.TEXT:
|
||||
if content.text_stage == TextStage.FINAL:
|
||||
# User transcription text added
|
||||
await self._report_user_transcription_text_added(content.text_content)
|
||||
|
||||
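The branching in `_handle_content_end_event` decides which finished text blocks get reported. Reduced to a pure function, and assuming the enum values are the strings `"TEXT"` and `"FINAL"` (an assumption, since `ContentType` and `TextStage` are defined elsewhere), the rule can be sketched as follows. `text_to_report` is a hypothetical name.

```python
def text_to_report(role, content_type, text_stage, stop_reason, assistant_is_responding):
    """Return "assistant", "user", or None for a finished content block."""
    # Only final text is ever reported; audio and non-final text are ignored.
    if content_type != "TEXT" or text_stage != "FINAL":
        return None
    if role == "ASSISTANT":
        # The "interrupted" message is not meaningful text, and assistant text
        # only counts while a response is in progress.
        if stop_reason != "INTERRUPTED" and assistant_is_responding:
            return "assistant"
        return None
    if role == "USER":
        return "user"
    return None
```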
async def _handle_completion_end_event(self, event_json):
|
||||
pass
|
||||
|
||||
async def _report_assistant_response_started(self):
|
||||
logger.debug("Assistant response started")
|
||||
|
||||
# Report that the assistant has started their response.
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Report that the equivalent of TTS (this is a speech-to-speech model) started
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
|
||||
async def _report_assistant_response_text_added(self, text):
|
||||
if not self._context: # should never happen
|
||||
return
|
||||
|
||||
logger.debug(f"Assistant response text added: {text}")
|
||||
|
||||
# Report some text added to the ongoing assistant response
|
||||
await self.push_frame(LLMTextFrame(text))
|
||||
|
||||
# Report some text added to the *equivalent* of TTS (this is a speech-to-speech model)
|
||||
await self.push_frame(TTSTextFrame(text))
|
||||
|
||||
# TODO: this is a (hopefully temporary) HACK. Here we directly manipulate the context rather
|
||||
# than relying on the frames pushed to the assistant context aggregator. The pattern of
|
||||
# receiving full-sentence text after the assistant has spoken does not easily fit with the
|
||||
# Pipecat expectation of chunks of text streaming in while the assistant is speaking.
|
||||
# Interruption handling was especially challenging. Rather than spend days trying to fit a
|
||||
# square peg in a round hole, I decided on this hack for the time being. We can most cleanly
|
||||
# abandon this hack if/when AWS Nova Sonic implements streaming smaller text chunks
|
||||
# interspersed with audio. Note that when we move away from this hack, we need to make sure
|
||||
# that on an interruption we avoid sending LLMFullResponseEndFrame, which gets the
|
||||
# LLMAssistantContextAggregator into a bad state.
|
||||
self._context.buffer_assistant_text(text)
|
||||
|
||||
async def _report_assistant_response_ended(self):
|
||||
if not self._context: # should never happen
|
||||
return
|
||||
|
||||
logger.debug("Assistant response ended")
|
||||
|
||||
# Report that the assistant has finished their response.
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
# Report that the equivalent of TTS (this is a speech-to-speech model) stopped.
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
# For an explanation of this hack, see _report_assistant_response_text_added.
|
||||
self._context.flush_aggregated_assistant_text()
|
||||
|
||||
async def _report_user_transcription_text_added(self, text):
|
||||
if not self._context: # should never happen
|
||||
return
|
||||
|
||||
logger.debug(f"User transcription text added: {text}")
|
||||
|
||||
# Manually add new user transcription text to context.
|
||||
# We can't rely on the user context aggregator to do this since it's upstream from the LLM.
|
||||
self._context.add_user_transcription_text(text)
|
||||
|
||||
# Report that some new user transcription text is available.
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(text=text, user_id="", timestamp=time_now_iso8601())
|
||||
)
|
||||
|
||||
#
|
||||
# context
|
||||
#
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AWSNovaSonicContextAggregatorPair:
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
|
||||
assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params)
|
||||
|
||||
return AWSNovaSonicContextAggregatorPair(user, assistant)
|
||||
|
||||
#
|
||||
# assistant response trigger (HACK)
|
||||
#
|
||||
|
||||
# Class variable
|
||||
AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION = (
|
||||
"Start speaking when you hear the user say 'ready', but don't consider that 'ready' to be "
|
||||
"a meaningful part of the conversation other than as a trigger for you to start speaking."
|
||||
)
|
||||
|
||||
async def trigger_assistant_response(self):
|
||||
if self._triggering_assistant_response:
|
||||
return False
|
||||
|
||||
self._triggering_assistant_response = True
|
||||
|
||||
# Read audio bytes, if we don't already have them cached
|
||||
if not self._assistant_response_trigger_audio:
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
# Send the trigger audio, if we're fully connected and set up
|
||||
if self._connected_time is not None:
|
||||
await self._send_assistant_response_trigger()
|
||||
self._triggering_assistant_response = False
|
||||
|
||||
async def _send_assistant_response_trigger(self):
|
||||
if (
|
||||
not self._assistant_response_trigger_audio or self._connected_time is None
|
||||
): # should never happen
|
||||
return
|
||||
|
||||
logger.debug("Sending assistant response trigger...")
|
||||
|
||||
chunk_duration = 0.02 # what we might get from InputAudioRawFrame
|
||||
chunk_size = int(
|
||||
chunk_duration
|
||||
* self._params.input_sample_rate
|
||||
* self._params.input_channel_count
|
||||
* (self._params.input_sample_size / 8)
|
||||
) # e.g. 0.02 seconds of 16-bit (2-byte) PCM mono audio at 16kHz is 640 bytes
|
||||
|
||||
# Lead with a bit of blank audio, if needed.
|
||||
# It seems like the LLM can't quite "hear" the first little bit of audio sent on a
|
||||
# connection.
|
||||
current_time = time.time()
|
||||
max_blank_audio_duration = 0.5
|
||||
blank_audio_duration = (
|
||||
max_blank_audio_duration - (current_time - self._connected_time)
|
||||
if self._connected_time is not None
|
||||
and (current_time - self._connected_time) < max_blank_audio_duration
|
||||
else None
|
||||
)
|
||||
if blank_audio_duration:
|
||||
logger.debug(
|
||||
f"Leading assistant response trigger with {blank_audio_duration}s of blank audio"
|
||||
)
|
||||
blank_audio_chunk = b"\x00" * chunk_size
|
||||
num_chunks = int(blank_audio_duration / chunk_duration)
|
||||
for _ in range(num_chunks):
|
||||
await self._send_user_audio_event(blank_audio_chunk)
|
||||
await asyncio.sleep(chunk_duration)
|
||||
|
||||
# Send trigger audio
|
||||
# NOTE: this audio *will* be transcribed and eventually make it into the context. That's OK:
|
||||
# if we ever need to seed this service again with context it would make sense to include it
|
||||
# since the instruction (i.e. the "wait for the trigger" instruction) will be part of the
|
||||
# context as well.
|
||||
audio_chunks = [
|
||||
self._assistant_response_trigger_audio[i : i + chunk_size]
|
||||
for i in range(0, len(self._assistant_response_trigger_audio), chunk_size)
|
||||
]
|
||||
for chunk in audio_chunks:
|
||||
await self._send_user_audio_event(chunk)
|
||||
await asyncio.sleep(chunk_duration)
|
||||
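The arithmetic in `_send_assistant_response_trigger` can be checked in isolation. The sketch below reproduces the chunk-size formula, the slicing of the trigger audio, and the rule for how much blank audio to lead with. The three function names are hypothetical and exist only for this illustration.

```python
def chunk_size_bytes(chunk_duration, sample_rate, channel_count, sample_size_bits):
    """Bytes of PCM audio covering `chunk_duration` seconds."""
    return int(chunk_duration * sample_rate * channel_count * (sample_size_bits / 8))


def split_into_chunks(audio, chunk_size):
    """Slice raw audio bytes into chunk_size pieces; the last may be shorter."""
    return [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]


def blank_lead_duration(seconds_since_connect, max_blank=0.5):
    """Seconds of silence still needed so that audio starts max_blank after connect."""
    remaining = max_blank - seconds_since_connect
    return remaining if remaining > 0 else None
```

With 16-bit mono audio at 16 kHz, `chunk_size_bytes(0.02, 16000, 1, 16)` gives 640, which matches the comment in the code above.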
217
src/pipecat/services/aws_nova_sonic/context.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "USER"
|
||||
ASSISTANT = "ASSISTANT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistoryMessage:
|
||||
role: Role # only USER and ASSISTANT
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistory:
|
||||
system_instruction: str = None
|
||||
messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMContext(OpenAILLMContext):
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self, system_instruction: str = ""):
|
||||
self._assistant_text = ""
|
||||
self._system_instruction = system_instruction
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_nova_sonic(
|
||||
obj: OpenAILLMContext, system_instruction: str
|
||||
) -> "AWSNovaSonicLLMContext":
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
|
||||
obj.__class__ = AWSNovaSonicLLMContext
|
||||
obj.__setup_local(system_instruction)
|
||||
return obj
|
||||
|
||||
# NOTE: this method has the side-effect of updating _system_instruction from messages
|
||||
def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
|
||||
history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)
|
||||
|
||||
# Bail if there are no messages
|
||||
if not self.messages:
|
||||
return history
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into "instruction"
|
||||
if messages[0].get("role") == "system":
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
history.system_instruction = content
|
||||
elif isinstance(content, list):
|
||||
history.system_instruction = content[0].get("text")
|
||||
if history.system_instruction:
|
||||
self._system_instruction = history.system_instruction
|
||||
|
||||
# Process remaining messages to fill out conversation history.
|
||||
# Nova Sonic supports "user" and "assistant" messages in history.
|
||||
for message in messages:
|
||||
history_message = self.from_standard_message(message)
|
||||
if history_message:
|
||||
history.messages.append(history_message)
|
||||
|
||||
return history
|
||||
|
||||
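To make the history conversion above concrete, here is a minimal standalone sketch that splits OpenAI-style messages into a system instruction plus user/assistant text, skipping tool entries. `split_history` is a hypothetical function, it returns plain tuples rather than the dataclasses used in this file, and it joins text parts with single spaces.

```python
import copy


def split_history(messages, default_instruction=None):
    """Return (system_instruction, [(ROLE, text), ...]) from OpenAI-style messages."""
    instruction = default_instruction
    messages = copy.deepcopy(messages)  # leave the caller's list untouched
    # A leading "system" message becomes the instruction.
    if messages and messages[0].get("role") == "system":
        content = messages.pop(0).get("content")
        if isinstance(content, str):
            instruction = content
        elif isinstance(content, list) and content:
            instruction = content[0].get("text")
    history = []
    for message in messages:
        role, content = message.get("role"), message.get("content")
        if role not in ("user", "assistant"):
            continue  # e.g. "tool" messages cannot be loaded into history
        if isinstance(content, list):
            content = " ".join(c["text"] for c in content if c.get("type") == "text")
        # Assistant tool-call entries have no content and are skipped.
        if content:
            history.append((role.upper(), content))
    return instruction, history
```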
def get_messages_for_persistent_storage(self):
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
# If we have a system instruction and messages doesn't already contain it, add it
|
||||
if self._system_instruction and not (messages and messages[0].get("role") == "system"):
|
||||
messages.insert(0, {"role": "system", "content": self._system_instruction})
|
||||
return messages
|
||||
|
||||
    def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage:
        role = message.get("role")
        if role in ("user", "assistant"):
            content = message.get("content")
            if isinstance(content, list):
                # Join the text parts with single spaces; log anything that isn't text.
                text_parts = []
                for c in content:
                    if c.get("type") == "text":
                        text_parts.append(c.get("text"))
                    else:
                        logger.error(
                            f"Unhandled content type in context message: {c.get('type')} - {message}"
                        )
                content = " ".join(text_parts)
            # There won't be content if this is an assistant tool call entry.
            # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
            # history
            if content:
                return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
        # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova
        # Sonic conversation history
        return None
|
||||
def add_user_transcription_text(self, text):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text}],
|
||||
}
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (user): {self.get_messages_for_logging()}")
|
||||
|
||||
def buffer_assistant_text(self, text):
|
||||
self._assistant_text += text
|
||||
# logger.debug(f"Assistant text buffered: {self._assistant_text}")
|
||||
|
||||
def flush_aggregated_assistant_text(self):
|
||||
if not self._assistant_text:
|
||||
return
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": self._assistant_text}],
|
||||
}
|
||||
self._assistant_text = ""
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}")
|
||||
|
||||
|
||||
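The buffer-then-flush pattern used by `buffer_assistant_text` and `flush_aggregated_assistant_text` can be illustrated without the surrounding context class. The sketch below uses a hypothetical `AssistantTextBuffer` with a plain list in place of the real context's message store.

```python
class AssistantTextBuffer:
    """Accumulate assistant text and commit it as one message on flush."""

    def __init__(self):
        self.messages = []
        self._assistant_text = ""

    def buffer(self, text):
        # Text arrives in pieces while the assistant is responding.
        self._assistant_text += text

    def flush(self):
        # Nothing buffered means nothing to add (e.g. a repeated flush).
        if not self._assistant_text:
            return
        self.messages.append(
            {"role": "assistant", "content": [{"type": "text", "text": self._assistant_text}]}
        )
        self._assistant_text = ""
```

Flushing once per response keeps a multi-sentence reply as a single assistant message.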
@dataclass
|
||||
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
|
||||
context: AWSNovaSonicLLMContext
|
||||
|
||||
|
||||
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Parent does not push LLMMessagesUpdateFrame
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context))
|
||||
|
||||
|
||||
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
# HACK: For now, disable the context aggregator by making it just pass through all frames
|
||||
# that the parent handles (except the function call stuff, which we still need).
|
||||
# For an explanation of this hack, see
|
||||
# AWSNovaSonicLLMService._report_assistant_response_text_added.
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
TextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
UserImageRawFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
),
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the LLM
|
||||
# itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side
|
||||
# context. Let's push a special frame to do that.
|
||||
await self.push_frame(
|
||||
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicContextAggregatorPair:
|
||||
_user: AWSNovaSonicUserContextAggregator
|
||||
_assistant: AWSNovaSonicAssistantContextAggregator
|
||||
|
||||
def user(self) -> AWSNovaSonicUserContextAggregator:
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
|
||||
return self._assistant
|
||||
14
src/pipecat/services/aws_nova_sonic/frames.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
|
||||
result_frame: FunctionCallResultFrame
|
||||
BIN
src/pipecat/services/aws_nova_sonic/ready.wav
Normal file
Binary file not shown.
@@ -1,13 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .metrics import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics")
|
||||
@@ -1,230 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
import wave
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
try:
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Canonical Metrics, you need to `pip install pipecat-ai[canonical]`. "
|
||||
+ "Also, set the `CANONICAL_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# Multipart upload part size in bytes, cannot be smaller than 5MB
|
||||
PART_SIZE = 1024 * 1024 * 5
|
||||
|
||||
|
||||
class CanonicalMetricsService(AIService):
|
||||
"""Metrics service that uploads conversation audio to the Canonical Voice API.
|
||||
|
||||
This class uses an AudioBufferProcessor to get the conversation audio and
|
||||
uploads it to Canonical Voice API for audio processing.
|
||||
|
||||
Args:
|
||||
call_id (str): Your unique identifier for the call. This is used to match the call in the Canonical Voice system to the call in your system.
|
||||
assistant (str): Identifier for the AI assistant. This can be whatever you want; it is intended for your convenience so you can distinguish
|
||||
between different assistants and a grouping mechanism for calls.
|
||||
assistant_speaks_first (bool, optional): Indicates if the assistant speaks first in the conversation. Defaults to True.
|
||||
output_dir (str, optional): Directory to save temporary audio files. Defaults to "recordings".
|
||||
|
||||
Attributes:
|
||||
call_id (str): Stores the unique call identifier.
|
||||
assistant (str): Stores the assistant identifier.
|
||||
assistant_speaks_first (bool): Indicates whether the assistant speaks first.
|
||||
output_dir (str): Directory path for saving temporary audio files.
|
||||
|
||||
The constructor also ensures that the output directory exists.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
call_id: str,
|
||||
assistant: str,
|
||||
api_key: str,
|
||||
api_url: str = "https://voiceapp.canonical.chat/api/v1",
|
||||
assistant_speaks_first: bool = True,
|
||||
output_dir: str = "recordings",
|
||||
audio_buffer_processor: Optional[AudioBufferProcessor] = None,
|
||||
context: Optional[OpenAILLMContext] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
# Validate that at least one of audio_buffer_processor or context is provided
|
||||
if audio_buffer_processor is None and context is None:
|
||||
raise ValueError("At least one of audio_buffer_processor or context must be specified")
|
||||
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._audio_buffer_processor = audio_buffer_processor
|
||||
self._api_key = api_key
|
||||
self._api_url = api_url
|
||||
self._call_id = call_id
|
||||
self._assistant = assistant
|
||||
self._assistant_speaks_first = assistant_speaks_first
|
||||
self._output_dir = output_dir
|
||||
self._context = context
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._process_completion()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._process_completion()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_completion(self):
|
||||
if self._audio_buffer_processor is not None:
|
||||
await self._process_audio()
|
||||
elif self._context is not None:
|
||||
await self._process_transcript()
|
||||
|
||||
async def _process_transcript(self):
|
||||
params = {
|
||||
"callId": self._call_id,
|
||||
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
|
||||
"transcript": self._context.messages,
|
||||
}
|
||||
response = await self._aiohttp_session.post(
|
||||
f"{self._api_url}/call",
|
||||
headers=self._request_headers(),
|
||||
json=params,
|
||||
)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to process transcript: {await response.text()}")
|
||||
|
||||
async def _process_audio(self):
|
||||
audio_buffer_processor = self._audio_buffer_processor
|
||||
|
||||
if not audio_buffer_processor.has_audio():
|
||||
return
|
||||
|
||||
os.makedirs(self._output_dir, exist_ok=True)
|
||||
filename = self._get_output_filename()
|
||||
audio = audio_buffer_processor.merge_audio_buffers()
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(audio_buffer_processor.num_channels)
|
||||
wf.setframerate(audio_buffer_processor.sample_rate)
|
||||
wf.writeframes(audio)
|
||||
async with aiofiles.open(filename, "wb") as file:
|
||||
await file.write(buffer.getvalue())
|
||||
|
||||
try:
|
||||
await self._multipart_upload(filename)
|
||||
await aiofiles.os.remove(filename)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload recording: {e}")
|
||||
|
||||
def _get_output_filename(self):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{self._output_dir}/{timestamp}-{uuid.uuid4().hex}.wav"
|
||||
|
||||
def _request_headers(self):
|
||||
return {"Content-Type": "application/json", "X-Canonical-Api-Key": self._api_key}
|
||||
|
||||
async def _multipart_upload(self, file_path: str):
|
||||
upload_request, upload_response = await self._request_upload(file_path)
|
||||
if upload_request is None or upload_response is None:
|
||||
return
|
||||
parts = await self._upload_parts(file_path, upload_response)
|
||||
if parts is None:
|
||||
return
|
||||
await self._upload_complete(parts, upload_request, upload_response)
|
||||
|
||||
async def _request_upload(self, file_path: str) -> Tuple[Dict, Dict]:
|
||||
filename = os.path.basename(file_path)
|
||||
filesize = os.path.getsize(file_path)
|
||||
numparts = int((filesize + PART_SIZE - 1) / PART_SIZE)
|
||||
|
||||
params = {
|
||||
"filename": filename,
|
||||
"parts": numparts,
|
||||
"callId": self._call_id,
|
||||
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
|
||||
}
|
||||
logger.debug(f"Requesting presigned URLs for {numparts} parts")
|
||||
response = await self._aiohttp_session.post(
|
||||
f"{self._api_url}/recording/uploadRequest", headers=self._request_headers(), json=params
|
||||
)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to get presigned URLs: {await response.text()}")
|
||||
return None, None
|
||||
response_json = await response.json()
|
||||
return params, response_json
|
||||
|
||||
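The part count in `_request_upload` is a ceiling division of the file size by `PART_SIZE`. As a small sketch, the same count can be computed with integer arithmetic only, which avoids going through a float. `num_parts` is a hypothetical helper, and the floor-division form is a suggested variant rather than what the code above does.

```python
PART_SIZE = 1024 * 1024 * 5  # 5 MB, the minimum part size noted above


def num_parts(filesize, part_size=PART_SIZE):
    """Number of parts needed to cover `filesize` bytes (ceiling division)."""
    return (filesize + part_size - 1) // part_size
```

A file of exactly `PART_SIZE` bytes needs one part, and one extra byte pushes it to two.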
async def _upload_parts(self, file_path: str, upload_response: Dict) -> List[Dict]:
|
||||
urls = upload_response["urls"]
|
||||
parts = []
|
||||
try:
|
||||
async with aiofiles.open(file_path, "rb") as file:
|
||||
for partnum, upload_url in enumerate(urls, start=1):
|
||||
data = await file.read(PART_SIZE)
|
||||
if not data:
|
||||
break
|
||||
|
||||
response = await self._aiohttp_session.put(upload_url, data=data)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to upload part {partnum}: {await response.text()}")
|
||||
return None
|
||||
|
||||
etag = response.headers["ETag"]
|
||||
parts.append({"partnum": str(partnum), "etag": etag})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Multipart upload aborted, an error occurred: {str(e)}")
|
||||
return parts
|
||||
|
||||
async def _upload_complete(
|
||||
self, parts: List[Dict], upload_request: Dict, upload_response: Dict
|
||||
):
|
||||
params = {
|
||||
"filename": upload_request["filename"],
|
||||
"parts": parts,
|
||||
"slug": upload_response["slug"],
|
||||
"callId": self._call_id,
|
||||
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
|
||||
}
|
||||
if self._context is not None:
|
||||
params["transcript"] = self._context.messages
|
||||
|
||||
logger.debug(f"Completing upload for {params['filename']}")
|
||||
logger.debug(f"Slug: {params['slug']}")
|
||||
response = await self._aiohttp_session.post(
|
||||
f"{self._api_url}/recording/uploadComplete",
|
||||
headers=self._request_headers(),
|
||||
json=params,
|
||||
)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to complete upload: {await response.text()}")
|
||||
return
|
||||
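Note on the multipart upload above: the part numbering and the shape of the `parts` list can be sketched in isolation. This is a minimal illustration only, not the code from the diff. It uses synchronous file I/O instead of `aiofiles`, a deliberately tiny `PART_SIZE` (the real value is not shown here), and the hypothetical helper names `iter_parts`, `build_parts_manifest` and `demo`.

```python
import os
import tempfile
from typing import Iterator, List, Tuple

# Tiny, illustrative part size; the real PART_SIZE is not visible in the diff.
PART_SIZE = 5


def iter_parts(file_path: str, part_size: int = PART_SIZE) -> Iterator[Tuple[int, bytes]]:
    """Yield (partnum, data) pairs, numbering parts from 1 as _upload_parts does."""
    with open(file_path, "rb") as file:
        partnum = 1
        while True:
            data = file.read(part_size)
            if not data:
                break
            yield partnum, data
            partnum += 1


def build_parts_manifest(etags: List[str]) -> List[dict]:
    """Build the `parts` list: string part numbers paired with their ETags."""
    return [{"partnum": str(i), "etag": etag} for i, etag in enumerate(etags, start=1)]


def demo() -> Tuple[List[Tuple[int, bytes]], List[dict]]:
    """Split a 12-byte temp file into parts and build a manifest with fake ETags."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"abcdefghijkl")
        path = tmp.name
    try:
        chunks = list(iter_parts(path))
    finally:
        os.remove(path)
    # Placeholder ETags; in the diff they come from each PUT response's ETag header.
    manifest = build_parts_manifest([f"etag-{n}" for n, _ in chunks])
    return chunks, manifest
```

Calling `demo()` splits the file into parts of 5, 5 and 2 bytes. The manifest keeps part numbers as strings, matching the `{"partnum": str(partnum), "etag": etag}` entries built in `_upload_parts`.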
@@ -30,7 +30,7 @@ class DeepgramTTSService(TTSService):
         self,
         *,
         api_key: str,
-        voice: str = "aura-helios-en",
+        voice: str = "aura-2-helena-en",
         base_url: str = "",
         sample_rate: Optional[int] = None,
         encoding: str = "linear16",
@@ -7,11 +7,12 @@
 import asyncio
 import base64
 import json
+import uuid
 from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union
 
 import aiohttp
 from loguru import logger
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel
 
 from pipecat.frames.frames import (
     CancelFrame,
@@ -26,7 +27,10 @@ from pipecat.frames.frames import (
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService
+from pipecat.services.tts_service import (
+    AudioContextWordTTSService,
+    WordTTSService,
+)
 from pipecat.transcriptions.language import Language
 
 # See .env.example for ElevenLabs configuration needed
@@ -159,26 +163,17 @@ def calculate_word_times(
     return word_times
 
 
-class ElevenLabsTTSService(InterruptibleWordTTSService):
+class ElevenLabsTTSService(AudioContextWordTTSService):
     class InputParams(BaseModel):
         language: Optional[Language] = None
-        optimize_streaming_latency: Optional[str] = None
         stability: Optional[float] = None
         similarity_boost: Optional[float] = None
         style: Optional[float] = None
         use_speaker_boost: Optional[bool] = None
         speed: Optional[float] = None
         auto_mode: Optional[bool] = True
-
-        @model_validator(mode="after")
-        def validate_voice_settings(self):
-            stability = self.stability
-            similarity_boost = self.similarity_boost
-            if (stability is None) != (similarity_boost is None):
-                raise ValueError(
-                    "Both 'stability' and 'similarity_boost' must be provided when using voice settings"
-                )
-            return self
+        enable_ssml_parsing: Optional[bool] = None
+        enable_logging: Optional[bool] = None
 
     def __init__(
         self,
@@ -220,13 +215,14 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             "language": self.language_to_service_language(params.language)
             if params.language
             else None,
-            "optimize_streaming_latency": params.optimize_streaming_latency,
             "stability": params.stability,
             "similarity_boost": params.similarity_boost,
             "style": params.style,
             "use_speaker_boost": params.use_speaker_boost,
             "speed": params.speed,
             "auto_mode": str(params.auto_mode).lower(),
+            "enable_ssml_parsing": params.enable_ssml_parsing,
+            "enable_logging": params.enable_logging,
         }
         self.set_model_name(model)
         self.set_voice(voice_id)
@@ -238,6 +234,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
         self._started = False
         self._cumulative_time = 0
 
+        # Context management for v1 multi API
+        self._context_id = None
         self._receive_task = None
         self._keepalive_task = None
 
@@ -253,15 +251,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
     async def set_model(self, model: str):
         await super().set_model(model)
         logger.info(f"Switching TTS model to: [{model}]")
-        await self._disconnect()
-        await self._connect()
+        # No need to disconnect/reconnect for model changes with multi-context API
 
     async def _update_settings(self, settings: Mapping[str, Any]):
         prev_voice = self._voice_id
         await super()._update_settings(settings)
+        # If voice changes, we don't need to reconnect, just use a new context
         if not prev_voice == self._voice_id:
-            await self._disconnect()
-            await self._connect()
             logger.info(f"Switching TTS voice to: [{self._voice_id}]")
 
     async def start(self, frame: StartFrame):
@@ -278,8 +274,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
         await self._disconnect()
 
     async def flush_audio(self):
-        if self._websocket:
-            msg = {"text": " ", "flush": True}
+        if self._websocket and self._context_id:
+            msg = {"context_id": self._context_id, "flush": True}
             await self._websocket.send(json.dumps(msg))
 
     async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -319,10 +315,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             voice_id = self._voice_id
             model = self.model_name
             output_format = self._output_format
-            url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
+            url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
 
-            if self._settings["optimize_streaming_latency"]:
-                url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
+            if self._settings["enable_ssml_parsing"]:
+                url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}"
+
+            if self._settings["enable_logging"]:
+                url += f"&enable_logging={self._settings['enable_logging']}"
 
             # Language can only be used with the ELEVENLABS_MULTILINGUAL_MODELS
             language = self._settings["language"]
@@ -337,14 +336,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             # Set max websocket message size to 16MB for large audio responses
             self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
 
-            # According to ElevenLabs, we should always start with a single space.
-            msg: Dict[str, Any] = {
-                "text": " ",
-                "xi_api_key": self._api_key,
-            }
-            if self._voice_settings:
-                msg["voice_settings"] = self._voice_settings
-            await self._websocket.send(json.dumps(msg))
         except Exception as e:
             logger.error(f"{self} initialization error: {e}")
             self._websocket = None
@@ -356,12 +347,15 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
 
             if self._websocket:
                 logger.debug("Disconnecting from ElevenLabs")
-                await self._websocket.send(json.dumps({"text": ""}))
+                # Close all contexts and the socket
+                if self._context_id:
+                    await self._websocket.send(json.dumps({"close_socket": True}))
                 await self._websocket.close()
         except Exception as e:
             logger.error(f"{self} error closing websocket: {e}")
         finally:
             self._started = False
+            self._context_id = None
             self._websocket = None
 
     def _get_websocket(self):
@@ -369,9 +363,35 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             return self._websocket
         raise Exception("Websocket not connected")
 
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+
+        # Close the current context when interrupted without closing the websocket
+        if self._context_id and self._websocket:
+            logger.trace(f"Closing context {self._context_id} due to interruption")
+            try:
+                await self._websocket.send(
+                    json.dumps({"context_id": self._context_id, "close_context": True})
+                )
+            except Exception as e:
+                logger.error(f"Error closing context on interruption: {e}")
+            self._context_id = None
+            self._started = False
+
     async def _receive_messages(self):
         async for message in self._get_websocket():
             msg = json.loads(message)
+            # Check if this message belongs to the current context
+            # The default context may return null/None for context_id
+            received_ctx_id = msg.get("context_id")
+            if (
+                self._context_id is not None
+                and received_ctx_id is not None
+                and received_ctx_id != self._context_id
+            ):
+                logger.trace(f"Ignoring message from different context: {received_ctx_id}")
+                continue
+
             if msg.get("audio"):
                 await self.stop_ttfb_metrics()
                 self.start_word_timestamps()
@@ -383,20 +403,45 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
                 word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
                 await self.add_word_timestamps(word_times)
                 self._cumulative_time = word_times[-1][1]
+            if msg.get("is_final"):
+                logger.trace(f"Received final message for context {received_ctx_id}")
+                # Context has finished
+                if self._context_id == received_ctx_id:
+                    self._context_id = None
+                    self._started = False
 
     async def _keepalive_task_handler(self):
         while True:
             await asyncio.sleep(10)
             try:
-                await self._send_text("")
+                # Send an empty message to keep the connection alive
+                if self._websocket and self._websocket.open:
+                    await self._websocket.send(json.dumps({}))
             except websockets.ConnectionClosed as e:
                 logger.warning(f"{self} keepalive error: {e}")
                 break
 
     async def _send_text(self, text: str):
         if self._websocket:
-            msg = {"text": text + " "}
-            await self._websocket.send(json.dumps(msg))
+            if not self._context_id:
+                # First message for a new context - need a space to initialize
+                msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
+
+                # Add voice settings only in first message for a context
+                if self._voice_settings:
+                    msg["voice_settings"] = self._voice_settings
+
+                await self._websocket.send(json.dumps(msg))
+                self._context_id = msg["context_id"]
+                logger.trace(f"Created new context {self._context_id}")
+
+                # Now send the actual text content
+                msg = {"text": text, "context_id": self._context_id}
+                await self._websocket.send(json.dumps(msg))
+            else:
+                # Continuing with an existing context
+                msg = {"text": text, "context_id": self._context_id}
+                await self._websocket.send(json.dumps(msg))
 
     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
         logger.debug(f"{self}: Generating TTS [{text}]")
@@ -406,6 +451,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
                 await self._connect()
 
             try:
+                # Close previous context if there was one
+                if self._context_id and not self._started:
+                    await self._websocket.send(
+                        json.dumps({"context_id": self._context_id, "close_context": True})
+                    )
+                    self._context_id = None
+
                 if not self._started:
                     await self.start_ttfb_metrics()
                     yield TTSStartedFrame()
@@ -417,8 +469,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             except Exception as e:
                 logger.error(f"{self} error sending message: {e}")
                 yield TTSStoppedFrame()
-                await self._disconnect()
-                await self._connect()
+                self._started = False
+                self._context_id = None
                 return
             yield None
         except Exception as e:
|
||||
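Note on the ElevenLabs hunks above: the message sequence for one context (initialize, send text, flush, close) can be sketched as plain JSON payloads. This is a rough illustration only. The field names (`text`, `context_id`, `xi_api_key`, `voice_settings`, `flush`, `close_context`) are taken from the diff; the helper functions and the dummy key are hypothetical, and nothing is sent over a websocket.

```python
import json
import uuid
from typing import Any, Dict, List, Optional


def init_context_message(
    api_key: str, voice_settings: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """First message of a new context: a single space, a fresh context id, the API key."""
    msg: Dict[str, Any] = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": api_key}
    # Voice settings are attached only to the first message of a context.
    if voice_settings:
        msg["voice_settings"] = voice_settings
    return msg


def text_message(context_id: str, text: str) -> Dict[str, Any]:
    """Follow-up text for an existing context."""
    return {"text": text, "context_id": context_id}


def flush_message(context_id: str) -> Dict[str, Any]:
    """Ask the service to flush buffered audio for this context."""
    return {"context_id": context_id, "flush": True}


def close_context_message(context_id: str) -> Dict[str, Any]:
    """Close one context while keeping the websocket open (e.g. on interruption)."""
    return {"context_id": context_id, "close_context": True}


def demo_sequence() -> List[str]:
    """Serialize the messages for one utterance in the order they would be sent."""
    first = init_context_message("dummy-key", {"stability": 0.5, "similarity_boost": 0.8})
    ctx = first["context_id"]
    frames = [first, text_message(ctx, "Hello "), flush_message(ctx), close_context_message(ctx)]
    return [json.dumps(frame) for frame in frames]
```

Because every message after the first carries only the `context_id`, an interruption can close that context and the next utterance can start a new one over the same connection, which appears to be why the reconnects were dropped from `set_model`, `_update_settings` and the error path in `run_tts`.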
@@ -9,8 +9,9 @@ from typing import List, Literal, Optional
 from pydantic import BaseModel
 
 from pipecat.frames.frames import Frame
+from pipecat.observers.base_observer import FramePushed
 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
-from pipecat.processors.frameworks.rtvi import RTVIObserver
+from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
 from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame
 
 
@@ -27,18 +28,13 @@ class RTVIBotLLMSearchResponseMessage(BaseModel):
 
 
 class GoogleRTVIObserver(RTVIObserver):
-    def __init__(self, rtvi: FrameProcessor):
+    def __init__(self, rtvi: RTVIProcessor):
         super().__init__(rtvi)
 
-    async def on_push_frame(
-        self,
-        src: FrameProcessor,
-        dst: FrameProcessor,
-        frame: Frame,
-        direction: FrameDirection,
-        timestamp: int,
-    ):
-        await super().on_push_frame(src, dst, frame, direction, timestamp)
+    async def on_push_frame(self, data: FramePushed):
+        await super().on_push_frame(data)
+
+        frame = data.frame
 
         if isinstance(frame, LLMSearchResponseFrame):
             await self._handle_llm_search_response_frame(frame)
@@ -190,6 +190,7 @@ class LLMService(AIService):
         function_name: Optional[str] = None,
         tool_call_id: Optional[str] = None,
         text_content: Optional[str] = None,
+        video_source: Optional[str] = None,
     ):
         await self.push_frame(
             UserImageRequestFrame(
@@ -197,6 +198,7 @@ class LLMService(AIService):
                 function_name=function_name,
                 tool_call_id=tool_call_id,
                 context=text_content,
+                video_source=video_source,
             ),
             FrameDirection.UPSTREAM,
         )
@@ -577,15 +577,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
             arguments = json.loads(item.arguments)
             if self.has_function(function_name):
                 run_llm = index == total_items - 1
-                if function_name in self._functions.keys():
-                    await self.call_function(
-                        context=self._context,
-                        tool_call_id=tool_id,
-                        function_name=function_name,
-                        arguments=arguments,
-                        run_llm=run_llm,
-                    )
-                elif None in self._functions.keys():
+                if function_name in self._functions.keys() or None in self._functions.keys():
                     await self.call_function(
                         context=self._context,
                         tool_call_id=tool_id,
@@ -15,7 +15,7 @@ from pipecat.frames.frames import (
     StartFrame,
     SystemFrame,
 )
-from pipecat.observers.base_observer import BaseObserver
+from pipecat.observers.base_observer import BaseObserver, FramePushed
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -42,14 +42,10 @@ class HeartbeatsObserver(BaseObserver):
         self._target = target
         self._callback = heartbeat_callback
 
-    async def on_push_frame(
-        self,
-        src: FrameProcessor,
-        dst: FrameProcessor,
-        frame: Frame,
-        direction: FrameDirection,
-        timestamp: int,
-    ):
+    async def on_push_frame(self, data: FramePushed):
+        src = data.source
+        frame = data.frame
+
         if src == self._target and isinstance(frame, HeartbeatFrame):
             await self._callback(self._target, frame)
 
|
||||
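Note on the observer hunks above: `on_push_frame` now takes a single `FramePushed` object instead of five positional arguments. The sketch below shows the pattern with a stand-in dataclass. Only `source` and `frame` are confirmed by the diff; the `destination`, `direction` and `timestamp` fields are assumed from the old signature, and `RecordingObserver` and the stub `HeartbeatFrame` are hypothetical.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FramePushed:
    """Stand-in for the real event object; field set partly assumed (see note above)."""

    source: Any
    destination: Any
    frame: Any
    direction: str
    timestamp: int


class HeartbeatFrame:
    """Stub frame type used only for this illustration."""


class RecordingObserver:
    """Records timestamps of heartbeat frames pushed by one target processor."""

    def __init__(self, target: Any):
        self._target = target
        self.seen: List[int] = []

    async def on_push_frame(self, data: FramePushed):
        # Unpack only the fields this observer needs from the event object.
        if data.source == self._target and isinstance(data.frame, HeartbeatFrame):
            self.seen.append(data.timestamp)


async def demo() -> List[int]:
    target, other = object(), object()
    observer = RecordingObserver(target)
    await observer.on_push_frame(FramePushed(target, other, HeartbeatFrame(), "downstream", 1))
    await observer.on_push_frame(FramePushed(other, target, HeartbeatFrame(), "downstream", 2))
    await observer.on_push_frame(FramePushed(target, other, "not a heartbeat", "downstream", 3))
    return observer.seen
```

With a single event object, new fields can be added later without changing every observer's signature, which is one plausible motivation for the change.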
@@ -122,6 +122,7 @@ class BaseInputTransport(FrameProcessor):
         # Configure VAD analyzer.
         if self._params.vad_analyzer:
             self._params.vad_analyzer.set_sample_rate(self._sample_rate)
+
         # Configure End of turn analyzer.
         if self._params.turn_analyzer:
             self._params.turn_analyzer.set_sample_rate(self._sample_rate)
@@ -129,10 +130,6 @@ class BaseInputTransport(FrameProcessor):
         # Start audio filter.
         if self._params.audio_in_filter:
             await self._params.audio_in_filter.start(self._sample_rate)
-        # Create audio input queue and task if needed.
-        if not self._audio_task and self._params.audio_in_enabled:
-            self._audio_in_queue = asyncio.Queue()
-            self._audio_task = self.create_task(self._audio_task_handler())
 
     async def stop(self, frame: EndFrame):
         # Cancel and wait for the audio input task to finish.
@@ -149,6 +146,13 @@ class BaseInputTransport(FrameProcessor):
             await self.cancel_task(self._audio_task)
             self._audio_task = None
 
+    async def set_transport_ready(self, frame: StartFrame):
+        """To be called when the transport is ready to stream."""
+        # Create audio input queue and task if needed.
+        if not self._audio_task and self._params.audio_in_enabled:
+            self._audio_in_queue = asyncio.Queue()
+            self._audio_task = self.create_task(self._audio_task_handler())
+
     async def push_audio_frame(self, frame: InputAudioRawFrame):
         if self._params.audio_in_enabled:
             await self._audio_in_queue.put(frame)
|
||||
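Note on the `BaseInputTransport` hunks above: creation of the audio queue and its consumer task moves out of `start()` into `set_transport_ready()`, which each concrete transport calls once it is connected. A minimal asyncio sketch of that ordering follows. `SketchInputTransport` is a hypothetical stand-in that uses plain `asyncio.create_task` rather than the framework's task helpers.

```python
import asyncio
from typing import List, Optional, Tuple


class SketchInputTransport:
    """Defers creation of the audio queue and consumer task until the transport is ready."""

    def __init__(self) -> None:
        self._audio_task: Optional[asyncio.Task] = None
        self._audio_in_queue: Optional[asyncio.Queue] = None
        self.received: List[bytes] = []

    async def start(self) -> None:
        # Configuration only; no queue or task is created here.
        pass

    async def set_transport_ready(self) -> None:
        # Create the audio input queue and task once, when streaming can begin.
        if not self._audio_task:
            self._audio_in_queue = asyncio.Queue()
            self._audio_task = asyncio.create_task(self._audio_task_handler())

    async def push_audio_frame(self, frame: bytes) -> None:
        await self._audio_in_queue.put(frame)

    async def _audio_task_handler(self) -> None:
        while True:
            frame = await self._audio_in_queue.get()
            self.received.append(frame)
            self._audio_in_queue.task_done()

    async def stop(self) -> None:
        if self._audio_task:
            self._audio_task.cancel()
            try:
                await self._audio_task
            except asyncio.CancelledError:
                pass
            self._audio_task = None


async def demo() -> Tuple[bool, List[bytes]]:
    transport = SketchInputTransport()
    await transport.start()
    no_task_after_start = transport._audio_task is None
    await transport.set_transport_ready()
    await transport.push_audio_frame(b"\x00\x01")
    await transport._audio_in_queue.join()  # wait until the frame has been consumed
    await transport.stop()
    return no_task_after_start, transport.received
```

In this sketch nothing consumes audio until `set_transport_ready()` runs, so frames cannot be processed before the underlying connection exists.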
@@ -78,6 +78,16 @@ class BaseOutputTransport(FrameProcessor):
         audio_bytes_10ms = int(self._sample_rate / 100) * self._params.audio_out_channels * 2
         self._audio_chunk_size = audio_bytes_10ms * self._params.audio_out_10ms_chunks
 
+    async def stop(self, frame: EndFrame):
+        for _, sender in self._media_senders.items():
+            await sender.stop(frame)
+
+    async def cancel(self, frame: CancelFrame):
+        for _, sender in self._media_senders.items():
+            await sender.cancel(frame)
+
+    async def set_transport_ready(self, frame: StartFrame):
+        """To be called when the transport is ready to stream."""
         # Register destinations.
         for destination in self._params.audio_out_destinations:
             await self.register_audio_destination(destination)
@@ -112,14 +122,6 @@ class BaseOutputTransport(FrameProcessor):
             )
             await self._media_senders[destination].start(frame)
 
-    async def stop(self, frame: EndFrame):
-        for _, sender in self._media_senders.items():
-            await sender.stop(frame)
-
-    async def cancel(self, frame: CancelFrame):
-        for _, sender in self._media_senders.items():
-            await sender.cancel(frame)
-
     async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
         pass
 
@@ -61,6 +61,8 @@ class LocalAudioInputTransport(BaseInputTransport):
         )
         self._in_stream.start_stream()
 
+        await self.set_transport_ready(frame)
+
     async def cleanup(self):
         await super().cleanup()
         if self._in_stream:
@@ -111,6 +113,8 @@ class LocalAudioOutputTransport(BaseOutputTransport):
         )
         self._out_stream.start_stream()
 
+        await self.set_transport_ready(frame)
+
     async def cleanup(self):
         await super().cleanup()
         if self._out_stream:
@@ -68,6 +68,8 @@ class TkInputTransport(BaseInputTransport):
         )
         self._in_stream.start_stream()
 
+        await self.set_transport_ready(frame)
+
     async def cleanup(self):
         await super().cleanup()
         if self._in_stream:
@@ -124,6 +126,8 @@ class TkOutputTransport(BaseOutputTransport):
         )
         self._out_stream.start_stream()
 
+        await self.set_transport_ready(frame)
+
     async def cleanup(self):
         await super().cleanup()
         if self._out_stream:
@@ -131,6 +131,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
         await self._client.trigger_client_connected()
         if not self._receive_task:
             self._receive_task = self.create_task(self._receive_messages())
+        await self.set_transport_ready(frame)
 
     async def _stop_tasks(self):
         if self._monitor_websocket_task:
@@ -204,6 +205,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
         await self._client.setup(frame)
         await self._params.serializer.setup(frame)
         self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -395,6 +395,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
             self._receive_audio_task = self.create_task(self._receive_audio())
         if not self._receive_video_task and self._params.video_in_enabled:
             self._receive_video_task = self.create_task(self._receive_video())
+        await self.set_transport_ready(frame)
 
     async def _stop_tasks(self):
         if self._receive_audio_task:
@@ -487,6 +488,7 @@ class SmallWebRTCOutputTransport(BaseOutputTransport):
         await super().start(frame)
         await self._client.setup(self._params, frame)
         await self._client.connect()
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -136,6 +136,7 @@ class WebsocketClientInputTransport(BaseInputTransport):
         await self._params.serializer.setup(frame)
         await self._session.setup(frame)
         await self._session.connect()
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -186,6 +187,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
         await self._params.serializer.setup(frame)
         await self._session.setup(frame)
         await self._session.connect()
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -83,6 +83,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
         await self._params.serializer.setup(frame)
         if not self._server_task:
             self._server_task = self.create_task(self._server_task_handler())
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -195,6 +196,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
         await super().start(frame)
         await self._params.serializer.setup(frame)
         self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2
+        await self.set_transport_ready(frame)
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -175,6 +175,7 @@ class DailyCallbacks(BaseModel):
     """Callback handlers for Daily events.
 
     Attributes:
+        on_active_speaker_changed: Called when the active speaker of the call has changed.
         on_joined: Called when bot successfully joined a room.
         on_left: Called when bot left a room.
         on_error: Called when an error occurs.
@@ -201,6 +202,7 @@ class DailyCallbacks(BaseModel):
         on_recording_error: Called when recording encounters an error.
     """
 
+    on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
     on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
     on_left: Callable[[], Awaitable[None]]
     on_error: Callable[[str], Awaitable[None]]
@@ -698,7 +700,7 @@ class DailyTransportClient(EventHandler):
 
         await self.update_subscriptions(participant_settings={participant_id: media})
 
-        self._audio_renderers[participant_id] = {audio_source: callback}
+        self._audio_renderers.setdefault(participant_id, {})[audio_source] = callback
 
         self._client.set_audio_renderer(
             participant_id,
@@ -722,7 +724,7 @@ class DailyTransportClient(EventHandler):
 
         await self.update_subscriptions(participant_settings={participant_id: media})
 
-        self._video_renderers[participant_id] = {video_source: callback}
+        self._video_renderers.setdefault(participant_id, {})[video_source] = callback
 
         self._client.set_video_renderer(
             participant_id,
@@ -789,6 +791,9 @@ class DailyTransportClient(EventHandler):
     # Daily (EventHandler)
     #
 
+    def on_active_speaker_changed(self, participant):
+        self._call_async_callback(self._callbacks.on_active_speaker_changed, participant)
+
     def on_app_message(self, message: Any, sender: str):
         self._call_async_callback(self._callbacks.on_app_message, message, sender)
 
@@ -944,19 +949,23 @@ class DailyInputTransport(BaseInputTransport):
             self._audio_in_task = self.create_task(self._audio_in_task_handler())
 
     async def start(self, frame: StartFrame):
+        # Setup client.
+        await self._client.setup(frame)
+
+        # Parent start.
+        await super().start(frame)
+
         if self._initialized:
             return
 
         self._initialized = True
 
-        # Parent start.
-        await super().start(frame)
-
-        # Setup client.
-        await self._client.setup(frame)
-
         # Join the room.
         await self._client.join()
+
+        # Indicate the transport that we are connected.
+        await self.set_transport_ready(frame)
+
         if self._params.audio_in_stream_on_start:
             self.start_audio_in_streaming()
 
@@ -1052,12 +1061,13 @@ class DailyInputTransport(BaseInputTransport):
         video_source: str = "camera",
         color_format: str = "RGB",
     ):
-        self._video_renderers[participant_id] = {
-            video_source: {
-                "framerate": framerate,
-                "timestamp": 0,
-                "render_next_frame": [],
-            }
+        if participant_id not in self._video_renderers:
+            self._video_renderers[participant_id] = {}
+
+        self._video_renderers[participant_id][video_source] = {
+            "framerate": framerate,
+            "timestamp": 0,
+            "render_next_frame": [],
         }
 
         await self._client.capture_participant_video(
@@ -1066,7 +1076,8 @@ class DailyInputTransport(BaseInputTransport):
 
     async def request_participant_image(self, frame: UserImageRequestFrame):
         if frame.user_id in self._video_renderers:
-            self._video_renderers[frame.user_id]["render_next_frame"].append(frame)
+            video_source = frame.video_source if frame.video_source else "camera"
+            self._video_renderers[frame.user_id][video_source]["render_next_frame"].append(frame)
 
     async def _on_participant_video_frame(
         self, participant_id: str, video_frame: VideoFrame, video_source: str
@@ -1125,20 +1136,23 @@ class DailyOutputTransport(BaseOutputTransport):
         self._initialized = False
 
     async def start(self, frame: StartFrame):
+        # Setup client.
+        await self._client.setup(frame)
+
+        # Parent start.
+        await super().start(frame)
+
         if self._initialized:
             return
 
         self._initialized = True
 
-        # Parent start.
-        await super().start(frame)
-
-        # Setup client.
-        await self._client.setup(frame)
-
         # Join the room.
         await self._client.join()
 
+        # Indicate the transport that we are connected.
+        await self.set_transport_ready(frame)
+
     async def stop(self, frame: EndFrame):
         # Parent stop.
         await super().stop(frame)
@@ -1201,6 +1215,7 @@ class DailyTransport(BaseTransport):
         super().__init__(input_name=input_name, output_name=output_name)
 
         callbacks = DailyCallbacks(
+            on_active_speaker_changed=self._on_active_speaker_changed,
             on_joined=self._on_joined,
             on_left=self._on_left,
             on_error=self._on_error,
@@ -1236,6 +1251,7 @@ class DailyTransport(BaseTransport):
 
         # Register supported handlers. The user will only be able to register
         # these handlers.
+        self._register_event_handler("on_active_speaker_changed")
         self._register_event_handler("on_joined")
         self._register_event_handler("on_left")
         self._register_event_handler("on_error")
@@ -1370,6 +1386,9 @@ class DailyTransport(BaseTransport):
     async def update_remote_participants(self, remote_participants: Mapping[str, Any]):
         await self._client.update_remote_participants(remote_participants=remote_participants)
 
+    async def _on_active_speaker_changed(self, participant: Any):
+        await self._call_event_handler("on_active_speaker_changed", participant)
+
     async def _on_joined(self, data):
         await self._call_event_handler("on_joined", data)
 
|
||||
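Note on the Daily renderer hunks above: the old assignment replaced a participant's whole renderer mapping each time a source was registered, so a second source silently dropped the first. The sketch below contrasts the two forms on plain dictionaries. The helper names are hypothetical and the source names are only examples.

```python
from typing import Callable, Dict, List, Tuple

Renderers = Dict[str, Dict[str, Callable]]


def register_overwrite(
    renderers: Renderers, participant_id: str, source: str, callback: Callable
) -> None:
    """Old behaviour: replaces the participant's mapping, losing earlier sources."""
    renderers[participant_id] = {source: callback}


def register_setdefault(
    renderers: Renderers, participant_id: str, source: str, callback: Callable
) -> None:
    """New behaviour: creates the inner mapping once, then adds or updates one source."""
    renderers.setdefault(participant_id, {})[source] = callback


def demo() -> Tuple[List[str], List[str]]:
    old: Renderers = {}
    new: Renderers = {}
    for source in ("camera", "screenVideo"):
        register_overwrite(old, "p1", source, print)
        register_setdefault(new, "p1", source, print)
    return sorted(old["p1"]), sorted(new["p1"])
```

After registering two sources for the same participant, the overwrite form keeps only the last one while the `setdefault` form keeps both. That is what lets `request_participant_image` look up a renderer by `video_source`.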
@@ -370,6 +370,7 @@ class LiveKitInputTransport(BaseInputTransport):
         await self._client.connect()
         if not self._audio_in_task and self._params.audio_in_enabled:
             self._audio_in_task = self.create_task(self._audio_in_task_handler())
+        await self.set_transport_ready(frame)
         logger.info("LiveKitInputTransport started")
 
     async def stop(self, frame: EndFrame):
@@ -441,6 +442,7 @@ class LiveKitOutputTransport(BaseOutputTransport):
         await super().start(frame)
         await self._client.setup(frame)
         await self._client.connect()
+        await self.set_transport_ready(frame)
         logger.info("LiveKitOutputTransport started")
 
     async def stop(self, frame: EndFrame):
@@ -6,7 +6,7 @@
 
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Coroutine, Optional, Set
+from typing import Coroutine, Dict, Optional, Sequence, Set
 
 from loguru import logger
 
@@ -69,14 +69,14 @@ class BaseTaskManager(ABC):
         pass
 
     @abstractmethod
-    def current_tasks(self) -> Set[asyncio.Task]:
+    def current_tasks(self) -> Sequence[asyncio.Task]:
         """Returns the list of currently created/registered tasks."""
         pass
 
 
 class TaskManager(BaseTaskManager):
     def __init__(self) -> None:
-        self._tasks: Set[asyncio.Task] = set()
+        self._tasks: Dict[str, asyncio.Task] = {}
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
     def set_event_loop(self, loop: asyncio.AbstractEventLoop):
@@ -179,16 +179,17 @@ class TaskManager(BaseTaskManager):
         finally:
             self._remove_task(task)
 
-    def current_tasks(self) -> Set[asyncio.Task]:
+    def current_tasks(self) -> Sequence[asyncio.Task]:
         """Returns the list of currently created/registered tasks."""
-        return self._tasks
+        return list(self._tasks.values())
 
     def _add_task(self, task: asyncio.Task):
-        self._tasks.add(task)
+        name = task.get_name()
+        self._tasks[name] = task
 
     def _remove_task(self, task: asyncio.Task):
         name = task.get_name()
         try:
-            self._tasks.remove(task)
+            del self._tasks[name]
         except KeyError as e:
             logger.trace(f"{name}: unable to remove task (already removed?): {e}")
|
||||
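Note on the `TaskManager` hunks above: tasks are now stored in a dictionary keyed by task name rather than in a set, and `current_tasks()` returns a copy as a list. A rough stand-alone sketch of that bookkeeping follows. `NamedTaskRegistry` is a hypothetical class, not the real `TaskManager`.

```python
import asyncio
from typing import Dict, List, Sequence


class NamedTaskRegistry:
    """Tracks asyncio tasks by name, mirroring the dict-based bookkeeping in the diff."""

    def __init__(self) -> None:
        self._tasks: Dict[str, asyncio.Task] = {}

    def add(self, task: asyncio.Task) -> None:
        self._tasks[task.get_name()] = task

    def remove(self, task: asyncio.Task) -> bool:
        """Return True if the task was registered, False if it was already removed."""
        try:
            del self._tasks[task.get_name()]
            return True
        except KeyError:
            return False

    def current_tasks(self) -> Sequence[asyncio.Task]:
        # Return a copy so callers can iterate while tasks are added or removed.
        return list(self._tasks.values())


async def demo() -> List[str]:
    registry = NamedTaskRegistry()
    task = asyncio.create_task(asyncio.sleep(0), name="worker#1")
    registry.add(task)
    names = [t.get_name() for t in registry.current_tasks()]
    await task
    first_removal = registry.remove(task)
    second_removal = registry.remove(task)
    return names + [str(first_removal), str(second_removal)]
```

One consequence of keying by name is that two tasks sharing a name would overwrite each other in the registry, so this design relies on task names being unique.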
@@ -1 +1 @@
--e ".[anthropic,google,langchain]"
+-e ".[anthropic,aws,google,langchain]"
@@ -40,6 +40,11 @@ from pipecat.services.anthropic.llm import (
     AnthropicLLMContext,
     AnthropicUserContextAggregator,
 )
+from pipecat.services.aws.llm import (
+    AWSBedrockAssistantContextAggregator,
+    AWSBedrockLLMContext,
+    AWSBedrockUserContextAggregator,
+)
 from pipecat.services.google.llm import (
     GoogleAssistantContextAggregator,
     GoogleLLMContext,
@@ -669,26 +674,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
     AGGREGATOR_CLASS = LLMUserContextAggregator
 
 
-#
-# OpenAI
-#
-
-
-class TestOpenAIUserContextAggregator(
-    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
-):
-    CONTEXT_CLASS = OpenAILLMContext
-    AGGREGATOR_CLASS = OpenAIUserContextAggregator
-
-
-class TestOpenAIAssistantContextAggregator(
-    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
-):
-    CONTEXT_CLASS = OpenAILLMContext
-    AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
-    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
-
-
 #
 # Anthropic
 #
@@ -724,6 +709,43 @@ class TestAnthropicAssistantContextAggregator(
         assert context.messages[index]["content"][0]["content"] == json.dumps(content)
 
 
+#
+# AWS (Bedrock)
+#
+
+
+class TestAWSBedrockUserContextAggregator(
+    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
+):
+    CONTEXT_CLASS = AWSBedrockLLMContext
+    AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
+
+    def check_message_multi_content(
+        self, context: OpenAILLMContext, content_index: int, index: int, content: str
+    ):
+        messages = context.messages[content_index]
+        assert messages["content"][index]["text"] == content
+
+
+class TestAWSBedrockAssistantContextAggregator(
+    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
+):
+    CONTEXT_CLASS = AWSBedrockLLMContext
+    AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
+    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
+
+    def check_message_multi_content(
+        self, context: OpenAILLMContext, content_index: int, index: int, content: str
+    ):
+        messages = context.messages[content_index]
+        assert messages["content"][index]["text"] == content
+
+    def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
+        assert context.messages[index]["content"][0]["toolResult"]["content"][0][
+            "text"
+        ] == json.dumps(content)
+
+
 #
 # Google
 #
@@ -766,3 +788,23 @@ class TestGoogleAssistantContextAggregator(
     def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
         obj = glm.Content.to_dict(context.messages[index])
         assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
+
+
+#
+# OpenAI
+#
+
+
+class TestOpenAIUserContextAggregator(
+    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
+):
+    CONTEXT_CLASS = OpenAILLMContext
+    AGGREGATOR_CLASS = OpenAIUserContextAggregator
+
+
+class TestOpenAIAssistantContextAggregator(
+    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
+):
+    CONTEXT_CLASS = OpenAILLMContext
+    AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
+    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
@@ -9,6 +9,7 @@ import unittest
 from pipecat.frames.frames import (
     EndFrame,
     Frame,
+    StartInterruptionFrame,
     TextFrame,
     TranscriptionFrame,
     UserStartedSpeakingFrame,
@@ -57,8 +58,8 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
 
     async def test_system_frame(self):
         filter = FrameFilter(types=())
-        frames_to_send = [UserStartedSpeakingFrame()]
-        expected_down_frames = [UserStartedSpeakingFrame]
+        frames_to_send = [StartInterruptionFrame()]
+        expected_down_frames = [StartInterruptionFrame]
         await run_test(
             filter,
             frames_to_send=frames_to_send,
@@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam
 from pipecat.adapters.schemas.function_schema import FunctionSchema
 from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
 from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
+from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
 from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
 from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
 from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
@@ -174,3 +175,32 @@ class TestFunctionAdapters(unittest.TestCase):
         tools_def = self.tools_def
         tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
         assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
+
+    def test_bedrock_adapter(self):
+        """Test AWS Bedrock adapter format transformation."""
+        expected = [
+            {
+                "toolSpec": {
+                    "name": "get_weather",
+                    "description": "Get the weather in a given location",
+                    "inputSchema": {
+                        "json": {
+                            "type": "object",
+                            "properties": {
+                                "format": {
+                                    "type": "string",
+                                    "enum": ["celsius", "fahrenheit"],
+                                    "description": "The temperature unit to use.",
+                                },
+                                "location": {
+                                    "type": "string",
+                                    "description": "The city, e.g. San Francisco",
+                                },
+                            },
+                            "required": ["location", "format"],
+                        }
+                    },
+                }
+            }
+        ]
+        assert AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def) == expected