Compare commits

..

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
0bf61d9e58 move observer proxies to FrameProcessor 2025-05-05 11:31:15 -07:00
68 changed files with 1252 additions and 4266 deletions

View File

@@ -5,71 +5,34 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.0.67] - 2025-05-07
## [Unreleased]
### Added
- Added `DebugLogObserver` for detailed frame logging with configurable
filtering by frame type and endpoint. This observer automatically extracts
and formats all frame data fields for debug logging.
- `UserImageRequestFrame.video_source` field has been added to request an image
from the desired video source.
- Added support for the AWS Nova Sonic speech-to-speech model with the new
`AWSNovaSonicLLMService`.
See https://docs.aws.amazon.com/nova/latest/userguide/speech.html.
Note that it requires Python >= 3.12 and `pip install pipecat-ai[aws-nova-sonic]`.
- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`.
- Added `on_active_speaker_changed` event handler to the `DailyTransport` class.
- Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in
`ElevenLabsTTSService`.
- Added support to `RimeHttpTTSService` for the `arcana` model.
### Changed
- Updated `ElevenLabsTTSService` to use the beta websocket API
(multi-stream-input). This new API supports context_ids and cancelling those
contexts, which greatly improves interruption handling.
- Observers `on_push_frame()` now take a single argument `FramePushed` instead
of multiple arguments.
- Updated the default voice for `DeepgramTTSService` to `aura-2-helena-en`.
### Removed
- `StartFrame.observer` has been replaced by `StartFrame.observers`. This allows
a user to pass all observers instead of a single proxy one.
### Deprecated
- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead.
- Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now
deprecated, use `on_push_frame(data: FramePushed)` instead.
### Fixed
- Fixed a `DailyTransport` issue that was causing issues when multiple audio or
video sources where being captured.
- Fixed a `UltravoxSTTService` issue that would cause the service to generate
all tokens as one word.
- Fixed a `PipelineTask` issue that would cause tasks to not be cancelled if
task was cancelled from outside of Pipecat.
- Fixed a `TaskManager` that was causing dangling tasks to be reported.
- Fixed an issue that could cause data to be sent to the transports when they
were still not ready.
- Remove custom audio tracks from `DailyTransport` before leaving.
### Removed
- Removed `CanonicalMetricsService` as it's no longer maintained.
## [0.0.66] - 2025-05-02
### Added

View File

@@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs:
## 🧩 Available services
| Category | Services |
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
| Category | Services |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)

View File

@@ -10,6 +10,7 @@ pipecat-ai[anthropic]
pipecat-ai[assemblyai]
pipecat-ai[aws]
pipecat-ai[azure]
pipecat-ai[canonical]
pipecat-ai[cartesia]
pipecat-ai[cerebras]
pipecat-ai[deepseek]

161
examples/canonical-metrics/.gitignore vendored Normal file
View File

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
recordings/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
runpod.toml

View File

@@ -0,0 +1,10 @@
FROM python:3.10-bullseye
RUN mkdir /app
COPY *.py /app/
COPY requirements.txt /app/
WORKDIR /app
RUN pip3 install -r requirements.txt
EXPOSE 7860
CMD ["python3", "server.py"]

View File

@@ -0,0 +1,66 @@
# Chatbot with canonical-metrics
This project implements a chatbot using a pipeline architecture that integrates audio processing, transcription, and a language model for conversational interactions. The chatbot operates within a daily communication environment, utilizing various services for text-to-speech and language model responses.
## Features
- **Audio Input and Output**: Captures microphone input and plays back audio responses.
- **Voice Activity Detection**: Utilizes Silero VAD to manage audio input intelligently.
- **Text-to-Speech**: Integrates ElevenLabs TTS service to convert text responses into audio.
- **Language Model Interaction**: Uses OpenAI's GPT-4 model to generate responses based on user input.
- **Transcription Services**: Captures and transcribes participant speech for analytics.
- **Metrics Collection**: Sends audio data for analysis via Canonical Metrics Service.
## Requirements
- Python 3.10+
- `python-dotenv`
- Additional libraries from the `pipecat` package.
## Setup
1. Clone the repository.
2. Install the required packages.
3. Set up environment variables for API keys:
- `OPENAI_API_KEY`
- `ELEVENLABS_API_KEY`
- `CANONICAL_API_KEY`
- `CANONICAL_API_URL`
4. Run the script.
## Usage
The chatbot introduces itself and engages in conversations, providing brief and creative responses. Designed for flexibility, it can support multiple languages with appropriate configuration.
## Events
- Participants joining or leaving the call are handled dynamically, adjusting the chatbot's behavior accordingly.
The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded.
## Get started
```python
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cp env.example .env # and add your credentials
```
## Run the server
```bash
python server.py
```
Then, visit `http://localhost:7860/` in your browser to start a chatbot session.
## Build and test the Docker image
```
docker build -t chatbot .
docker run --env-file .env -p 7860:7860 chatbot
```

View File

@@ -0,0 +1,146 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import uuid
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import EndFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.services.canonical.metrics import CanonicalMetricsService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Chatbot",
DailyParams(
audio_out_enabled=True,
audio_in_enabled=True,
video_out_enabled=False,
vad_analyzer=SileroVADAnalyzer(),
transcription_enabled=True,
#
# Spanish
#
# transcription_settings=DailyTranscriptionSettings(
# language="es",
# tier="nova",
# model="2-general"
# )
),
)
tts = ElevenLabsTTSService(
api_key=os.getenv("ELEVENLABS_API_KEY"),
#
# English
#
voice_id="cgSgspJ2msm6clMCkdW9",
#
# Spanish
#
# model="eleven_multilingual_v2",
# voice_id="gD1IexrzCvsXPHUuT0s3",
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
messages = [
{
"role": "system",
#
# English
#
"content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself. Keep all your responses to 12 words or fewer.",
#
# Spanish
#
# "content": "Eres Chatbot, un amigable y útil robot. Tu objetivo es demostrar tus capacidades de una manera breve. Tus respuestas se convertiran a audio así que nunca no debes incluir caracteres especiales. Contesta a lo que el usuario pregunte de una manera creativa, útil y breve. Empieza por presentarte a ti mismo.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
"""
CanonicalMetrics uses AudioBufferProcessor under the hood to buffer the audio. On
call completion, CanonicalMetrics will send the audio buffer to Canonical for
analysis. Visit https://voice.canonical.chat to learn more.
"""
audio_buffer_processor = AudioBufferProcessor(num_channels=2)
canonical = CanonicalMetricsService(
audio_buffer_processor=audio_buffer_processor,
aiohttp_session=session,
api_key=os.getenv("CANONICAL_API_KEY"),
call_id=str(uuid.uuid4()),
assistant="pipecat-chatbot",
assistant_speaks_first=True,
context=context,
)
pipeline = Pipeline(
[
transport.input(), # microphone
context_aggregator.user(),
llm,
tts,
transport.output(),
canonical, # uploads audio buffer to Canonical AI for metrics
audio_buffer_processor, # captures audio into a buffer
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await audio_buffer_processor.start_recording()
await transport.capture_participant_transcription(participant["id"])
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):
print(f"Participant left: {participant}")
await task.cancel()
@transport.event_handler("on_call_state_updated")
async def on_call_state_updated(transport, state):
if state == "left":
# Here we don't want to cancel, we just want to finish sending
# whatever is queued, so we use an EndFrame().
await task.queue_frame(EndFrame())
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,6 @@
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
DAILY_API_KEY=7df...
OPENAI_API_KEY=sk-PL...
ELEVENLABS_API_KEY=aeb...
CANONICAL_API_KEY=can...
CANONICAL_API_URL=

View File

@@ -0,0 +1,5 @@
python-dotenv
fastapi[all]
uvicorn
pipecat-ai[daily,openai,silero,elevenlabs,canonical]

View File

@@ -0,0 +1,55 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
import aiohttp
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper
async def configure(aiohttp_session: aiohttp.ClientSession):
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
parser.add_argument(
"-u", "--url", type=str, required=False, help="URL of the Daily room to join"
)
parser.add_argument(
"-k",
"--apikey",
type=str,
required=False,
help="Daily API Key (needed to create an owner token for the room)",
)
args, unknown = parser.parse_known_args()
url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL")
key = args.apikey or os.getenv("DAILY_API_KEY")
if not url:
raise Exception(
"No Daily room specified. use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL."
)
if not key:
raise Exception(
"No Daily API key specified. use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers."
)
daily_rest_helper = DailyRESTHelper(
daily_api_key=key,
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
aiohttp_session=aiohttp_session,
)
# Create a meeting token for the given room with an expiration 1 hour in
# the future.
expiry_time: float = 60 * 60
token = await daily_rest_helper.get_token(url, expiry_time)
return (url, token)

View File

@@ -0,0 +1,139 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
import subprocess
from contextlib import asynccontextmanager
import aiohttp
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
MAX_BOTS_PER_ROOM = 1
# Bot sub-process dict for status reporting and concurrency control
bot_procs = {}
daily_helpers = {}
load_dotenv(override=True)
def cleanup():
# Clean up function, just to be extra safe
for entry in bot_procs.values():
proc = entry[0]
proc.terminate()
proc.wait()
@asynccontextmanager
async def lifespan(app: FastAPI):
aiohttp_session = aiohttp.ClientSession()
daily_helpers["rest"] = DailyRESTHelper(
daily_api_key=os.getenv("DAILY_API_KEY", ""),
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
aiohttp_session=aiohttp_session,
)
yield
await aiohttp_session.close()
cleanup()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def start_agent(request: Request):
print(f"!!! Creating room")
room = await daily_helpers["rest"].create_room(DailyRoomParams())
print(f"!!! Room URL: {room.url}")
# Ensure the room property is present
if not room.url:
raise HTTPException(
status_code=500,
detail="Missing 'room' property in request data. Cannot start agent without a target room!",
)
# Check if there is already an existing process running in this room
num_bots_in_room = sum(
1 for proc in bot_procs.values() if proc[1] == room.url and proc[0].poll() is None
)
if num_bots_in_room >= MAX_BOTS_PER_ROOM:
raise HTTPException(status_code=500, detail=f"Max bot limited reach for room: {room.url}")
# Get the token for the room
token = await daily_helpers["rest"].get_token(room.url)
if not token:
raise HTTPException(status_code=500, detail=f"Failed to get token for room: {room.url}")
# Spawn a new agent, and join the user session
# Note: this is mostly for demonstration purposes (refer to 'deployment' in README)
try:
proc = subprocess.Popen(
[f"python3 -m bot -u {room.url} -t {token}"],
shell=True,
bufsize=1,
cwd=os.path.dirname(os.path.abspath(__file__)),
)
bot_procs[proc.pid] = (proc, room.url)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to start subprocess: {e}")
return RedirectResponse(room.url)
@app.get("/status/{pid}")
def get_status(pid: int):
# Look up the subprocess
proc = bot_procs.get(pid)
# If the subprocess doesn't exist, return an error
if not proc:
raise HTTPException(status_code=404, detail=f"Bot with process id: {pid} not found")
# Check the status of the subprocess
if proc[0].poll() is None:
status = "running"
else:
status = "finished"
return JSONResponse({"bot_id": pid, "status": status})
if __name__ == "__main__":
import uvicorn
default_host = os.getenv("HOST", "0.0.0.0")
default_port = int(os.getenv("FAST_API_PORT", "7860"))
parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
parser.add_argument("--host", type=str, default=default_host, help="Host address")
parser.add_argument("--port", type=int, default=default_port, help="Port number")
parser.add_argument("--reload", action="store_true", help="Reload code on change")
config = parser.parse_args()
uvicorn.run(
"server:app",
host=config.host,
port=config.port,
reload=config.reload,
)

View File

@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import aiohttp
@@ -22,23 +21,44 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
# Check if we're in local development mode
LOCAL_RUN = os.getenv("LOCAL_RUN")
if LOCAL_RUN:
import asyncio
import webbrowser
try:
from local_runner import configure
except ImportError:
logger.error("Could not import local_runner module. Local development mode may not work.")
# Load environment variables
load_dotenv(override=True)
# Check if we're in local development mode
LOCAL_RUN = os.getenv("LOCAL_RUN")
async def main(transport: DailyTransport):
async def main(room_url: str, token: str):
"""Main pipeline setup and execution function.
Args:
transport: The DailyTransport object for the bot
room_url: The Daily room URL
token: The Daily room token
"""
logger.debug("Starting bot")
logger.debug("Starting bot in room: {}", room_url)
transport = DailyTransport(
room_url,
token,
"bot",
DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
transcription_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"), voice_id="71a7ad14-091c-4e8e-a314-022ece01c121"
api_key=os.getenv("CARTESIA_API_KEY"), voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22"
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
@@ -106,25 +126,10 @@ async def bot(args: DailySessionArguments):
body: The configuration object from the request body
session_id: The session ID for logging
"""
from pipecat.audio.filters.krisp_filter import KrispFilter
logger.info(f"Bot process initialized {args.room_url} {args.token}")
transport = DailyTransport(
args.room_url,
args.token,
"Pipecat Bot",
DailyParams(
audio_in_enabled=True,
audio_in_filter=None if LOCAL_RUN else KrispFilter(),
audio_out_enabled=True,
transcription_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
try:
await main(transport)
await main(args.room_url, args.token)
logger.info("Bot process completed")
except Exception as e:
logger.exception(f"Error in bot process: {str(e)}")
@@ -132,27 +137,18 @@ async def bot(args: DailySessionArguments):
# Local development functions
async def local_daily():
async def local_main():
"""Function for local development testing."""
from local_runner import configure
try:
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Pipecat Bot",
DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
transcription_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
await main(transport)
logger.warning("_")
logger.warning("_")
logger.warning(f"Talk to your voice agent here: {room_url}")
logger.warning("_")
logger.warning("_")
webbrowser.open(room_url)
await main(room_url, token)
except Exception as e:
logger.exception(f"Error in local development mode: {e}")
@@ -160,6 +156,6 @@ async def local_daily():
# Local development entry point
if LOCAL_RUN and __name__ == "__main__":
try:
asyncio.run(local_daily())
asyncio.run(local_main())
except Exception as e:
logger.exception(f"Failed to run in local mode: {e}")

View File

@@ -1,4 +1,2 @@
CARTESIA_API_KEY=
OPENAI_API_KEY=
# Local dev only
DAILY_API_KEY=
OPENAI_API_KEY=

View File

@@ -7,7 +7,6 @@
import os
import aiohttp
from fastapi import HTTPException
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams

View File

@@ -1,8 +1,6 @@
agent_name = "my-first-agent"
image = "your-username/my-first-agent:0.1"
image_credentials = "your-dockerhub-creds"
secret_set = "my-first-agent-secrets"
enable_krisp = true
[scaling]
min_instances = 0

View File

@@ -47,7 +47,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
live_options=LiveOptions(vad_events=True, utterance_end_ms="1000"),
)
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

View File

@@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

View File

@@ -5,6 +5,7 @@
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
@@ -14,9 +15,9 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.aws.stt import AWSTranscribeSTTService
from pipecat.services.aws.tts import AWSPollyTTSService
from pipecat.services.aws.tts import PollyTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -36,19 +37,17 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
),
)
stt = AWSTranscribeSTTService()
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = AWSPollyTTSService(
region="us-west-2", # only specific regions support generative TTS
voice_id="Joanna",
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
tts = PollyTTSService(
api_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"),
voice_id="Amy",
params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"),
)
llm = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
messages = [
{
@@ -86,7 +85,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")

View File

@@ -1,139 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.aws.stt import AWSTranscribeSTTService
from pipecat.services.aws.tts import AWSPollyTTSService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
async def fetch_weather_from_api(params: FunctionCallParams):
await params.result_callback({"conditions": "nice", "temperature": "75"})
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
logger.info(f"Starting bot")
transport = SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
stt = AWSTranscribeSTTService()
tts = AWSPollyTTSService(
region="us-west-2", # only specific regions support generative TTS
voice_id="Joanna",
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
)
llm = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
tools = ToolsSchema(standard_tools=[weather_function])
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
if __name__ == "__main__":
from run import main
main()

View File

@@ -1,267 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import asyncio
import glob
import json
import os
from datetime import datetime
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws_nova_sonic.aws import AWSNovaSonicLLMService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
BASE_FILENAME = "/tmp/pipecat_conversation_"
async def fetch_weather_from_api(params: FunctionCallParams):
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
await params.result_callback(
{
"conditions": "nice",
"temperature": temperature,
"format": params.arguments["format"],
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
}
)
async def get_saved_conversation_filenames(params: FunctionCallParams):
# Construct the full pattern including the BASE_FILENAME
full_pattern = f"{BASE_FILENAME}*.json"
# Use glob to find all matching files
matching_files = glob.glob(full_pattern)
logger.debug(f"matching files: {matching_files}")
await params.result_callback({"filenames": matching_files})
# async def get_saved_conversation_filenames(
# function_name, tool_call_id, args, llm, context, result_callback
# ):
# pattern = re.compile(re.escape(BASE_FILENAME) + "\\d{8}_\\d{6}\\.json$")
# matching_files = []
# for filename in os.listdir("."):
# if pattern.match(filename):
# matching_files.append(filename)
# await result_callback({"filenames": matching_files})
async def save_conversation(params: FunctionCallParams):
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
filename = f"{BASE_FILENAME}{timestamp}.json"
try:
with open(filename, "w") as file:
messages = params.context.get_messages_for_persistent_storage()
# remove the last few messages. in reverse order, they are:
# - the in progress save tool call
# - the invocation of the save tool call
# - the user ask to save (which may encompass one or more messages)
# the simplest thing to do is to pop messages until the last one is an assistant
# response
while messages and not (
messages[-1].get("role") == "assistant" and "content" in messages[-1]
):
messages.pop()
if messages: # we never expect this to be empty
logger.debug(
f"writing conversation to {filename}\n{json.dumps(messages, indent=4)}"
)
json.dump(messages, file, indent=2)
await params.result_callback({"success": True})
except Exception as e:
await params.result_callback({"success": False, "error": str(e)})
async def load_conversation(params: FunctionCallParams):
async def _reset():
filename = params.arguments["filename"]
logger.debug(f"loading conversation from {filename}")
try:
with open(filename, "r") as file:
messages = json.load(file)
messages.append(
{
"role": "user",
"content": f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}",
}
)
params.context.set_messages(messages)
await params.llm.reset_conversation()
await params.llm.trigger_assistant_response()
except Exception as e:
await params.result_callback({"success": False, "error": str(e)})
asyncio.create_task(_reset())
get_current_weather_tool = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
save_conversation_tool = FunctionSchema(
name="save_conversation",
description="Save the current conversation. Use this function to persist the current conversation to external storage.",
properties={},
required=[],
)
get_saved_conversation_filenames_tool = FunctionSchema(
name="get_saved_conversation_filenames",
description="Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
properties={},
required=[],
)
load_conversation_tool = FunctionSchema(
name="load_conversation",
description="Load a conversation history. Use this function to load a conversation history into the current session.",
properties={
"filename": {
"type": "string",
"description": "The filename of the conversation history to load.",
}
},
required=["filename"],
)
tools = ToolsSchema(
standard_tools=[
get_current_weather_tool,
save_conversation_tool,
get_saved_conversation_filenames_tool,
load_conversation_tool,
]
)
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
logger.info(f"Starting bot")
transport = SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
),
)
# Specify initial system instruction.
# HACK: note that, for now, we need to inject a special bit of text into this instruction to
# allow the first assistant response to be programmatically triggered (which happens in the
# on_client_connected handler, below)
system_instruction = (
"You are a friendly assistant. The user and you will engage in a spoken dialog exchanging "
"the transcripts of a natural real-time conversation. Keep your responses short, generally "
"two or three sentences for chatty scenarios. "
f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
)
llm = AWSNovaSonicLLMService(
secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region
voice_id="tiffany", # matthew, tiffany, amy
# you could choose to pass instruction here rather than via context
# system_instruction=system_instruction,
# you could choose to pass tools here rather than via context
# tools=tools
)
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("save_conversation", save_conversation)
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
llm.register_function("load_conversation", load_conversation)
context = OpenAILLMContext(
messages=[
{"role": "system", "content": f"{system_instruction}"},
],
tools=tools,
)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(),
llm, # LLM
transport.output(), # Transport bot output
context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
# HACK: for now, we need this special way of triggering the first assistant response in AWS
# Nova Sonic. Note that this trigger requires a special corresponding bit of text in the
# system instruction. In the future, simply queueing the context frame should be sufficient.
await llm.trigger_assistant_response()
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
if __name__ == "__main__":
from run import main
main()

View File

@@ -14,26 +14,18 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
StartInterruptionFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
from pipecat.observers.loggers.llm_log_observer import LLMLogObserver
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -41,7 +33,7 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
class CustomObserver(BaseObserver):
class DebugObserver(BaseObserver):
"""Observer to log interruptions and bot speaking events to the console.
Logs all frame instances of:
@@ -66,7 +58,7 @@ class CustomObserver(BaseObserver):
# Create direction arrow
arrow = "" if direction == FrameDirection.DOWNSTREAM else ""
if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport):
if isinstance(frame, StartInterruptionFrame):
logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s")
elif isinstance(frame, BotStartedSpeakingFrame):
logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")
@@ -125,17 +117,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
observers=[
CustomObserver(),
LLMLogObserver(),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
EndFrame: None,
}
),
],
observers=[DebugObserver(), LLMLogObserver()],
)
@transport.event_handler("on_client_connected")

View File

@@ -1,173 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from datetime import datetime
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load environment variables
load_dotenv(override=True)
async def fetch_weather_from_api(params: FunctionCallParams):
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
await params.result_callback(
{
"conditions": "nice",
"temperature": temperature,
"format": params.arguments["format"],
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
}
)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
required=["location", "format"],
)
# Create tools schema
tools = ToolsSchema(standard_tools=[weather_function])
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
logger.info(f"Starting bot")
# Initialize the SmallWebRTCTransport with the connection
transport = SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_in_sample_rate=16000,
audio_out_enabled=True,
camera_in_enabled=False,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
),
)
# Specify initial system instruction.
# HACK: note that, for now, we need to inject a special bit of text into this instruction to
# allow the first assistant response to be programmatically triggered (which happens in the
# on_client_connected handler, below)
system_instruction = (
"You are a friendly assistant. The user and you will engage in a spoken dialog exchanging "
"the transcripts of a natural real-time conversation. Keep your responses short, generally "
"two or three sentences for chatty scenarios. "
f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
)
# Create the AWS Nova Sonic LLM service
llm = AWSNovaSonicLLMService(
secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region
voice_id="tiffany", # matthew, tiffany, amy
# you could choose to pass instruction here rather than via context
# system_instruction=system_instruction
# you could choose to pass tools here rather than via context
# tools=tools
)
# Register function for function calls
# you can either register a single function for all function calls, or specific functions
# llm.register_function(None, fetch_weather_from_api)
llm.register_function("get_current_weather", fetch_weather_from_api)
# Set up context and context management.
# AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to
# what's expected by Nova Sonic.
context = OpenAILLMContext(
messages=[
{"role": "system", "content": f"{system_instruction}"},
{
"role": "user",
"content": "Tell me a fun fact!",
},
],
tools=tools,
)
context_aggregator = llm.create_context_aggregator(context)
# Build the pipeline
pipeline = Pipeline(
[
transport.input(),
context_aggregator.user(),
llm,
transport.output(),
context_aggregator.assistant(),
]
)
# Configure the pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
),
)
# Handle client connection event
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
# HACK: for now, we need this special way of triggering the first assistant response in AWS
# Nova Sonic. Note that this trigger requires a special corresponding bit of text in the
# system instruction. In the future, simply queueing the context frame should be sufficient.
await llm.trigger_assistant_response()
# Handle client disconnection events
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
# Run the pipeline
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
if __name__ == "__main__":
from run import main
main()

View File

@@ -10,16 +10,12 @@ import subprocess
from contextlib import asynccontextmanager
import aiohttp
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
# Load environment variables from .env file
load_dotenv(override=True)
MAX_BOTS_PER_ROOM = 1
# Bot sub-process dict for status reporting and concurrency control

View File

@@ -10,16 +10,12 @@ import subprocess
from contextlib import asynccontextmanager
import aiohttp
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
# Load environment variables from .env file
load_dotenv(override=True)
MAX_BOTS_PER_ROOM = 1
# Bot sub-process dict for status reporting and concurrency control

View File

@@ -41,13 +41,13 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "assemblyai~=0.37.0" ]
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ]
aws = [ "boto3~=1.37.16" ]
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
canonical = [ "aiofiles~=24.1.0" ]
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
cerebras = []
deepseek = []
daily = [ "daily-python~=0.18.2" ]
daily = [ "daily-python~=0.18.1" ]
deepgram = [ "deepgram-sdk~=3.8.0" ]
elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.5.9" ]
@@ -97,7 +97,6 @@ where = ["src"]
[tool.setuptools.package-data]
"pipecat" = ["py.typed"]
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
[tool.pytest.ini_options]
addopts = "--verbose"

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema

View File

@@ -1,40 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import json
from typing import Any, Dict, List
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
class AWSNovaSonicLLMAdapter(BaseLLMAdapter):
@staticmethod
def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
return {
"toolSpec": {
"name": function.name,
"description": function.description,
"inputSchema": {
"json": json.dumps(
{
"type": "object",
"properties": function.properties,
"required": function.required,
}
)
},
}
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
"""Converts function schemas to AWS Nova Sonic function-calling format.
:return: AWS Nova Sonic formatted function call definition.
"""
functions_schema = tools_schema.standard_tools
return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema]

View File

@@ -1,38 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Dict, List
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
class AWSBedrockLLMAdapter(BaseLLMAdapter):
@staticmethod
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
return {
"toolSpec": {
"name": function.name,
"description": function.description,
"inputSchema": {
"json": {
"type": "object",
"properties": function.properties,
"required": function.required,
},
},
}
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
"""Converts function schemas to Bedrock's function-calling format.
:return: Bedrock formatted function call definition.
"""
functions_schema = tools_schema.standard_tools
return [self._to_bedrock_function_format(func) for func in functions_schema]

View File

@@ -16,6 +16,7 @@ from typing import (
Literal,
Mapping,
Optional,
Sequence,
Tuple,
)
@@ -77,8 +78,8 @@ class Frame:
@dataclass
class SystemFrame(Frame):
"""A frame that takes higher priority than other frames. System frames are
handled in order and are not affected by user interruptions.
"""System frames are frames that are not internally queued by any of the
frame processors and should be processed immediately.
"""
@@ -87,9 +88,8 @@ class SystemFrame(Frame):
@dataclass
class DataFrame(Frame):
"""A frame that is processed in order and usually contains data such as LLM
context, text, audio or images. Data frames are cancelled by user
interruptions.
"""Data frames are frames that will be processed in order and usually
contain data such as LLM context, text, audio or images.
"""
@@ -98,9 +98,9 @@ class DataFrame(Frame):
@dataclass
class ControlFrame(Frame):
"""A frame that, as data frames, is processed in order and usually contains
control information such as update settings or to end the pipeline after
everything is flushed. Control frames are cancelled by user interruptions.
"""Control frames are frames that, similar to data frames, will be processed
in order and usually contain control information such as frames to update
settings or to end the pipeline.
"""
@@ -450,12 +450,12 @@ class StartFrame(SystemFrame):
clock: BaseClock
task_manager: BaseTaskManager
observers: Sequence["BaseObserver"]
audio_in_sample_rate: int = 16000
audio_out_sample_rate: int = 24000
allow_interruptions: bool = False
enable_metrics: bool = False
enable_usage_metrics: bool = False
observer: Optional["BaseObserver"] = None
report_only_initial_ttfb: bool = False
@@ -691,7 +691,7 @@ class FunctionCallResultFrame(SystemFrame):
@dataclass
class STTMuteFrame(SystemFrame):
"""A frame to mute/unmute the STT service."""
"""System frame to mute/unmute the STT service."""
mute: bool
@@ -716,10 +716,9 @@ class UserImageRequestFrame(SystemFrame):
context: Optional[Any] = None
function_name: Optional[str] = None
tool_call_id: Optional[str] = None
video_source: Optional[str] = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, video_source: {self.video_source}, function: {self.function_name}, request: {self.tool_call_id})"
return f"{self.name}(user: {self.user_id}, function: {self.function_name}, request: {self.tool_call_id})"
@dataclass
@@ -797,7 +796,7 @@ class EndFrame(ControlFrame):
should be shut down. If the transport receives this frame, it will stop
sending frames to its output channel(s) and close all its threads. Note,
that this is a control frame, which means it will received in the order it
was sent.
was sent (unline system frames).
"""

View File

@@ -1,218 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from dataclasses import fields, is_dataclass
from enum import Enum, auto
from typing import Dict, Optional, Set, Tuple, Type, Union
from loguru import logger
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.frame_processor import FrameDirection
class FrameEndpoint(Enum):
"""Specifies which endpoint (source or destination) to filter on."""
SOURCE = auto()
DESTINATION = auto()
class DebugLogObserver(BaseObserver):
"""Observer that logs frame activity with detailed content to the console.
Automatically extracts and formats data from any frame type, making it useful
for debugging pipeline behavior without needing frame-specific observers.
Args:
frame_types: Optional tuple of frame types to log, or a dict with frame type
filters. If None, logs all frame types.
exclude_fields: Optional set of field names to exclude from logging.
Examples:
Log all frames from all services:
```python
observers = DebugLogObserver()
```
Log specific frame types from any source/destination:
```python
from pipecat.frames.frames import TranscriptionFrame, InterimTranscriptionFrame
observers = DebugLogObserver(frame_types=(TranscriptionFrame, InterimTranscriptionFrame))
```
Log frames with specific source/destination filters:
```python
from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
from pipecat.transports.base_output_transport import BaseOutputTransport
from pipecat.services.stt_service import STTService
observers = DebugLogObserver(frame_types={
# Only log StartInterruptionFrame when source is BaseOutputTransport
StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
# Only log UserStartedSpeakingFrame when destination is STTService
UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION),
# Log LLMTextFrame regardless of source or destination type
LLMTextFrame: None
})
```
"""
def __init__(
self,
frame_types: Optional[
Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]]
] = None,
exclude_fields: Optional[Set[str]] = None,
):
"""Initialize the debug log observer.
Args:
frame_types: Tuple of frame types to log, or a dict mapping frame types to
filter configurations. Filter configs can be:
- None to log all instances of the frame type
- A tuple of (service_type, endpoint) to filter on a specific service
and endpoint (SOURCE or DESTINATION)
If None is provided instead of a tuple/dict, log all frames.
exclude_fields: Set of field names to exclude from logging. If None, only binary
data fields are excluded.
"""
# Process frame filters
self.frame_filters = {}
if frame_types is not None:
if isinstance(frame_types, tuple):
# Tuple of frame types - log all instances
self.frame_filters = {frame_type: None for frame_type in frame_types}
else:
# Dict of frame types with filters
self.frame_filters = frame_types
# By default, exclude binary data fields that would clutter logs
self.exclude_fields = (
exclude_fields
if exclude_fields is not None
else {
"audio", # Skip binary audio data
"image", # Skip binary image data
"images", # Skip lists of images
}
)
def _format_value(self, value):
"""Format a value for logging.
Args:
value: The value to format.
Returns:
str: A string representation of the value suitable for logging.
"""
if value is None:
return "None"
elif isinstance(value, str):
return f"{value!r}"
elif isinstance(value, (list, tuple)):
if len(value) == 0:
return "[]"
if isinstance(value[0], dict) and len(value) > 3:
# For message lists, just show count
return f"{len(value)} items"
return str(value)
elif isinstance(value, (bytes, bytearray)):
return f"{len(value)} bytes"
elif hasattr(value, "get_messages_for_logging") and callable(
getattr(value, "get_messages_for_logging")
):
# Special case for OpenAI context
return f"{value.__class__.__name__} with messages: {value.get_messages_for_logging()}"
else:
return str(value)
def _should_log_frame(self, frame, src, dst):
"""Determine if a frame should be logged based on filters.
Args:
frame: The frame being processed
src: The source component
dst: The destination component
Returns:
bool: True if the frame should be logged, False otherwise
"""
# If no filters, log all frames
if not self.frame_filters:
return True
# Check if this frame type is in our filters
for frame_type, filter_config in self.frame_filters.items():
if isinstance(frame, frame_type):
# If filter is None, log all instances of this frame type
if filter_config is None:
return True
# Otherwise, check the specific filter
service_type, endpoint = filter_config
if endpoint == FrameEndpoint.SOURCE:
return isinstance(src, service_type)
elif endpoint == FrameEndpoint.DESTINATION:
return isinstance(dst, service_type)
return False
async def on_push_frame(self, data: FramePushed):
"""Process a frame being pushed into the pipeline.
Logs frame details to the console with all relevant fields and values.
Args:
data: Event data containing the frame, source, destination, direction, and timestamp.
"""
src = data.source
dst = data.destination
frame = data.frame
direction = data.direction
timestamp = data.timestamp
# Check if we should log this frame
if not self._should_log_frame(frame, src, dst):
return
# Format direction arrow
arrow = "" if direction == FrameDirection.DOWNSTREAM else ""
time_sec = timestamp / 1_000_000_000
class_name = frame.__class__.__name__
# Build frame representation
frame_details = []
# If dataclass, extract fields
if is_dataclass(frame):
for field in fields(frame):
if field.name in self.exclude_fields:
continue
value = getattr(frame, field.name)
if value is None:
continue
formatted_value = self._format_value(value)
frame_details.append(f"{field.name}: {formatted_value}")
# Format the message
if frame_details:
details = ", ".join(frame_details)
message = f"{class_name} {details} at {time_sec:.2f}s"
else:
message = f"{class_name} at {time_sec:.2f}s"
# Log the message
logger.debug(f"{src} {arrow} {dst}: {message}")

View File

@@ -7,6 +7,7 @@
from loguru import logger
from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
@@ -16,7 +17,7 @@ from pipecat.frames.frames import (
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService

View File

@@ -7,10 +7,12 @@
from loguru import logger
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
TranscriptionFrame,
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.stt_service import STTService

View File

@@ -32,7 +32,6 @@ from pipecat.metrics.metrics import ProcessingMetricsData, TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
@@ -184,6 +183,7 @@ class PipelineTask(BaseTask):
self._idle_timeout_secs = idle_timeout_secs
self._idle_timeout_frames = idle_timeout_frames
self._cancel_on_idle_timeout = cancel_on_idle_timeout
self._observers = observers
if self._params.observers:
import warnings
@@ -193,7 +193,7 @@ class PipelineTask(BaseTask):
"Field 'observers' is deprecated, use the 'observers' parameter instead.",
DeprecationWarning,
)
observers = self._params.observers
self._observers = self._params.observers
self._finished = False
# This queue receives frames coming from the pipeline upstream.
@@ -229,11 +229,6 @@ class PipelineTask(BaseTask):
# PipelineTask and its frame processors.
self._task_manager = task_manager or TaskManager()
# The task observer acts as a proxy to the provided observers. This way,
# we only need to pass a single observer (using the StartFrame) which
# then just acts as a proxy.
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
# These events can be used to check which frames make it to the source
# or sink processors. Instead of calling the event handlers for every
# frame the user needs to specify which events they are interested
@@ -286,7 +281,12 @@ class PipelineTask(BaseTask):
async def cancel(self):
"""Stops the running pipeline immediately."""
logger.debug(f"Canceling pipeline task {self}")
await self._cancel()
# Make sure everything is cleaned up downstream. This is sent
# out-of-band from the main streaming task which is what we want since
# we want to cancel right away.
await self._source.push_frame(CancelFrame())
# Only cancel the push task. Everything else will be cancelled in run().
await self._task_manager.cancel_task(self._process_push_task)
async def run(self):
"""Starts and manages the pipeline execution until completion or cancellation."""
@@ -304,17 +304,11 @@ class PipelineTask(BaseTask):
# well, because you get a CancelledError in every place you are
# awaiting a task.
pass
finally:
# It's possibe that we get an asyncio.CancelledError from the
# outside, if so we need to make sure everything gets cancelled
# properly.
if cleanup_pipeline:
await self._cancel()
await self._cancel_tasks()
await self._cleanup(cleanup_pipeline)
if self._check_dangling_tasks:
self._print_dangling_tasks()
self._finished = True
await self._cancel_tasks()
await self._cleanup(cleanup_pipeline)
if self._check_dangling_tasks:
self._print_dangling_tasks()
self._finished = True
async def queue_frame(self, frame: Frame):
"""Queue a single frame to be pushed down the pipeline.
@@ -337,14 +331,6 @@ class PipelineTask(BaseTask):
for frame in frames:
await self.queue_frame(frame)
async def _cancel(self):
# Make sure everything is cleaned up downstream. This is sent
# out-of-band from the main streaming task which is what we want since
# we want to cancel right away.
await self._source.push_frame(CancelFrame())
# Only cancel the push task. Everything else will be cancelled in run().
await self._task_manager.cancel_task(self._process_push_task)
async def _create_tasks(self):
self._process_up_task = self._task_manager.create_task(
self._process_up_queue(), f"{self}::_process_up_queue"
@@ -356,8 +342,6 @@ class PipelineTask(BaseTask):
self._process_push_queue(), f"{self}::_process_push_queue"
)
await self._observer.start()
return self._process_push_task
def _maybe_start_heartbeat_tasks(self):
@@ -376,8 +360,6 @@ class PipelineTask(BaseTask):
)
async def _cancel_tasks(self):
await self._observer.stop()
await self._task_manager.cancel_task(self._process_up_task)
await self._task_manager.cancel_task(self._process_down_task)
@@ -434,7 +416,7 @@ class PipelineTask(BaseTask):
audio_out_sample_rate=self._params.audio_out_sample_rate,
enable_metrics=self._params.enable_metrics,
enable_usage_metrics=self._params.enable_usage_metrics,
observer=self._observer,
observers=self._observers,
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
)
start_frame.metadata = self._params.start_metadata

View File

@@ -1,95 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import inspect
from typing import List
from attr import dataclass
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.utils.asyncio import BaseTaskManager
@dataclass
class Proxy:
"""This is the data we receive from the main observer and that we put into
a queue for later processing.
"""
queue: asyncio.Queue
task: asyncio.Task
observer: BaseObserver
class TaskObserver(BaseObserver):
"""This is a pipeline frame observer that is meant to be used as a proxy to
the user provided observers. That is, this is the observer that should be
passed to the frame processors. Then, every time a frame is pushed this
observer will call all the observers registered to the pipeline task.
This observer makes sure that passing frames to observers doesn't block the
pipeline by creating a queue and a task for each user observer. When a frame
is received, it will be put in a queue for efficiency and later processed by
each task.
"""
def __init__(self, *, observers: List[BaseObserver] = [], task_manager: BaseTaskManager):
self._observers = observers
self._task_manager = task_manager
self._proxies: List[Proxy] = []
async def start(self):
"""Starts all proxy observer tasks."""
self._proxies = self._create_proxies(self._observers)
async def stop(self):
"""Stops all proxy observer tasks."""
for proxy in self._proxies:
await self._task_manager.cancel_task(proxy.task)
async def on_push_frame(self, data: FramePushed):
for proxy in self._proxies:
await proxy.queue.put(data)
def _create_proxies(self, observers) -> List[Proxy]:
proxies = []
for observer in observers:
queue = asyncio.Queue()
task = self._task_manager.create_task(
self._proxy_task_handler(queue, observer),
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
)
proxy = Proxy(queue=queue, task=task, observer=observer)
proxies.append(proxy)
return proxies
async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver):
warning_reported = False
while True:
data = await queue.get()
signature = inspect.signature(observer.on_push_frame)
if len(signature.parameters) > 1:
if not warning_reported:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Observer `on_push_frame(source, destination, frame, direction, timestamp)` is deprecated, us `on_push_frame(data: FramePushed)` instead.",
DeprecationWarning,
)
warning_reported = True
await observer.on_push_frame(
data.src, data.dst, data.frame, data.direction, data.timestamp
)
else:
await observer.on_push_frame(data)
queue.task_done()

View File

@@ -5,9 +5,9 @@
#
import asyncio
from dataclasses import dataclass
import inspect
from enum import Enum
from typing import Awaitable, Callable, Coroutine, Optional
from typing import Awaitable, Callable, Coroutine, List, Optional, Sequence
from loguru import logger
@@ -22,9 +22,9 @@ from pipecat.frames.frames import (
SystemFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
from pipecat.observers.base_observer import FramePushed
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
from pipecat.utils.base_object import BaseObject
@@ -33,49 +33,50 @@ class FrameDirection(Enum):
UPSTREAM = 2
@dataclass
class FrameProcessorQueueItem:
frame: Frame
direction: FrameDirection
callback: Optional[Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]]
class FrameProcessorQueue:
def __init__(self):
class FrameObserverProxy:
def __init__(self, task_manager: TaskManager, observer: BaseObserver) -> None:
self._task_manager = task_manager
self._observer = observer
self._queue = asyncio.Queue()
self._urgent_queue = asyncio.Queue()
self._event = asyncio.Event()
self._task: Optional[asyncio.Task] = None
self._warning_reported = False
async def put(self, item: FrameProcessorQueueItem):
if isinstance(item.frame, SystemFrame):
await self._urgent_queue.put(item)
else:
await self._queue.put(item)
self._event.set()
def start(self):
if not self._task:
self._task = self._task_manager.create_task(
self._observer_task_handler(), f"ObserverProxy::{self._observer.__class__.__name__}"
)
async def get(self) -> FrameProcessorQueueItem:
# Wait for an item in any of the queues.
await self._event.wait()
async def stop(self):
if self._task:
await self._task_manager.cancel_task(self._task)
self._task = None
if self._urgent_queue.empty():
item = await self._queue.get()
async def observe(self, data: FramePushed):
await self._queue.put(data)
async def _observer_task_handler(self):
while True:
data = await self._queue.get()
signature = inspect.signature(self._observer.on_push_frame)
if len(signature.parameters) > 1:
if not self._warning_reported:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Observer `on_push_frame(source, destination, frame, direction, timestamp)` is deprecated, use `on_push_frame(data: FramePushed)` instead.",
DeprecationWarning,
)
self._warning_reported = True
await self._observer.on_push_frame(
data.source, data.destination, data.frame, data.direction, data.timestamp
)
else:
await self._observer.on_push_frame(data)
self._queue.task_done()
else:
item = await self._urgent_queue.get()
self._urgent_queue.task_done()
# Clear the event only if all queues are empty.
if self._queue.empty() and self._urgent_queue.empty():
self._event.clear()
return item
def clear(self):
self._queue = asyncio.Queue()
# Clear the event only if all queues are empty.
if self._queue.empty() and self._urgent_queue.empty():
self._event.clear()
class FrameProcessor(BaseObject):
@@ -102,7 +103,7 @@ class FrameProcessor(BaseObject):
self._enable_metrics = False
self._enable_usage_metrics = False
self._report_only_initial_ttfb = False
self._observer = None
self._observer_proxies: List[FrameObserverProxy] = []
# Cancellation is done through CancelFrame (a system frame). This could
# cause other events being triggered (e.g. closing a transport) which
@@ -115,21 +116,24 @@ class FrameProcessor(BaseObject):
self._metrics = metrics or FrameProcessorMetrics()
self._metrics.set_processor_name(self.name)
# Processors receive frames on a streaming queue which are then
# processed by a streaming task. This guarantees that all frames are
# processed in the same task. By default, the streaming queue is
# processed immediately but it may block if `pause_processing_frames()`
# Processors have an input queue. The input queue will be processed
# immediately (default) or it will block if `pause_processing_frames()`
# is called. To resume processing frames we need to call
# `resume_processing_frames()` which will wake up the event.
self.__should_block_frames = False
self.__streaming_event = asyncio.Event()
self.__streaming_queue = FrameProcessorQueue()
self.__streaming_frame_task: Optional[asyncio.Task] = None
self.__input_event = asyncio.Event()
self.__input_frame_task: Optional[asyncio.Task] = None
self.__process_queue = asyncio.Queue()
self.__process_task: Optional[asyncio.Task] = None
self.__process_urgent_queue = asyncio.Queue()
self.__process_urgent_task: Optional[asyncio.Task] = None
# Every processor in Pipecat should only output frames from a single
# task. This avoids problems like audio overlapping. System frames are
# the exception to this rule.
self.__push_frame_task: Optional[asyncio.Task] = None
# The observers task will push observed frames to the observers'
# proxies. This task avoids the pipeline to block. Then, there is one
# proxy per observer and each proxy has its own task. This way each
# observer is independent to the rest.
self.__observers_task: Optional[asyncio.Task] = None
@property
def id(self) -> int:
@@ -219,8 +223,8 @@ class FrameProcessor(BaseObject):
async def cleanup(self):
await super().cleanup()
await self.__cancel_input_task()
await self.__cancel_process_task()
await self.__cancel_process_urgent_task()
await self.__cancel_push_task()
await self.__cancel_observers_task()
def link(self, processor: "FrameProcessor"):
self._next = processor
@@ -265,7 +269,7 @@ class FrameProcessor(BaseObject):
await self.process_frame(frame, direction)
else:
# We queue everything else.
await self.__streaming_queue.put(FrameProcessorQueueItem(frame, direction, callback))
await self.__input_queue.put((frame, direction, callback))
async def pause_processing_frames(self):
logger.trace(f"{self}: pausing frame processing")
@@ -273,7 +277,7 @@ class FrameProcessor(BaseObject):
async def resume_processing_frames(self):
logger.trace(f"{self}: resuming frame processing")
self.__streaming_event.set()
self.__input_event.set()
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
@@ -283,7 +287,9 @@ class FrameProcessor(BaseObject):
self._enable_metrics = frame.enable_metrics
self._enable_usage_metrics = frame.enable_usage_metrics
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
self._observer = frame.observer
self._observer_proxies = self._create_observer_proxies(
frame.task_manager, frame.observers
)
await self.__start(frame)
elif isinstance(frame, StartInterruptionFrame):
await self._start_interruption()
@@ -300,48 +306,21 @@ class FrameProcessor(BaseObject):
if not self._check_ready(frame):
return
try:
timestamp = self._clock.get_time() if self._clock else 0
if direction == FrameDirection.DOWNSTREAM and self._next:
logger.trace(f"Pushing {frame} from {self} to {self._next}")
if self._observer:
data = FramePushed(
source=self,
destination=self._next,
frame=frame,
direction=direction,
timestamp=timestamp,
)
await self._observer.on_push_frame(data)
await self._next.queue_frame(frame, direction)
elif direction == FrameDirection.UPSTREAM and self._prev:
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
if self._observer:
data = FramePushed(
source=self,
destination=self._prev,
frame=frame,
direction=direction,
timestamp=timestamp,
)
await self._observer.on_push_frame(data)
await self._prev.queue_frame(frame, direction)
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
if isinstance(frame, SystemFrame):
await self.__internal_push_frame(frame, direction)
else:
await self.__push_queue.put((frame, direction))
async def __start(self, frame: StartFrame):
self.__create_process_task()
self.__create_process_urgent_task()
self.__create_observers_task()
self.__create_input_task()
self.__create_push_task()
async def __cancel(self, frame: CancelFrame):
self._cancelling = True
await self.__cancel_input_task()
await self.__cancel_process_task()
await self.__cancel_process_urgent_task()
await self.__cancel_push_task()
await self.__cancel_observers_task()
#
# Handle interruptions
@@ -349,32 +328,56 @@ class FrameProcessor(BaseObject):
async def _start_interruption(self):
try:
# Cancel the streaming task.
# Cancel the push frame task. This will stop pushing frames downstream.
await self.__cancel_push_task()
# Cancel the input task. This will stop processing queued frames.
await self.__cancel_input_task()
# Cancel the task processing frames. We do not cancel the task that
# is processing urgent frames.
await self.__cancel_process_task()
# If there's an interruption we should not block frames anymore.
self.__should_block_frames = False
# Clear the streaming queue, since we don't want to process its
# frame anymore (except system and urgent frames).
self.__streaming_queue.clear()
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
# Create a new tasks.
self.__create_process_task()
# Create a new input queue and task.
self.__create_input_task()
# Create a new output queue and task.
self.__create_push_task()
async def _stop_interruption(self):
# Nothing to do right now.
pass
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
try:
timestamp = self._clock.get_time() if self._clock else 0
if direction == FrameDirection.DOWNSTREAM and self._next:
logger.trace(f"Pushing {frame} from {self} to {self._next}")
data = FramePushed(
source=self,
destination=self._next,
frame=frame,
direction=direction,
timestamp=timestamp,
)
await self.__observe_frame(data)
await self._next.queue_frame(frame, direction)
elif direction == FrameDirection.UPSTREAM and self._prev:
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
data = FramePushed(
source=self,
destination=self._prev,
frame=frame,
direction=direction,
timestamp=timestamp,
)
await self.__observe_frame(data)
await self._prev.queue_frame(frame, direction)
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
def _check_ready(self, frame: Frame):
# If we are trying to push a frame but we still have no clock, it means
# we didn't process a StartFrame.
@@ -386,60 +389,82 @@ class FrameProcessor(BaseObject):
return True
def __create_input_task(self):
if not self.__streaming_frame_task:
self.__streaming_frame_task = self.create_task(self.__streaming_frame_task_handler())
if not self.__input_frame_task:
self.__should_block_frames = False
self.__input_event.clear()
self.__input_queue = asyncio.Queue()
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
async def __cancel_input_task(self):
if self.__streaming_frame_task:
await self.cancel_task(self.__streaming_frame_task)
self.__streaming_frame_task = None
if self.__input_frame_task:
await self.cancel_task(self.__input_frame_task)
self.__input_frame_task = None
def __create_process_task(self):
if not self.__process_task:
self.__process_queue = asyncio.Queue()
self.__process_task = self.create_task(
self.__process_task_handler(self.__process_queue)
)
async def __cancel_process_task(self):
if self.__process_task:
await self.cancel_task(self.__process_task)
self.__process_task = None
def __create_process_urgent_task(self):
if not self.__process_urgent_task:
self.__process_urgent_task = self.create_task(
self.__process_task_handler(self.__process_urgent_queue)
)
async def __cancel_process_urgent_task(self):
if self.__process_urgent_task:
await self.cancel_task(self.__process_urgent_task)
self.__process_urgent_task = None
async def __streaming_frame_task_handler(self):
async def __input_frame_task_handler(self):
while True:
if self.__should_block_frames:
logger.trace(f"{self}: frame processing paused")
await self.__streaming_event.wait()
self.__streaming_event.clear()
await self.__input_event.wait()
self.__input_event.clear()
self.__should_block_frames = False
logger.trace(f"{self}: frame processing resumed")
item = await self.__streaming_queue.get()
if isinstance(item.frame, SystemFrame):
await self.__process_urgent_queue.put(item)
else:
await self.__process_queue.put(item)
async def __process_task_handler(self, queue: asyncio.Queue):
while True:
item = await queue.get()
(frame, direction, callback) = await self.__input_queue.get()
# Process the frame.
await self.process_frame(item.frame, item.direction)
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
if item.callback:
await item.callback(self, item.frame, item.direction)
if callback:
await callback(self, frame, direction)
self.__input_queue.task_done()
def __create_push_task(self):
if not self.__push_frame_task:
self.__push_queue = asyncio.Queue()
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
async def __cancel_push_task(self):
if self.__push_frame_task:
await self.cancel_task(self.__push_frame_task)
self.__push_frame_task = None
async def __push_frame_task_handler(self):
while True:
(frame, direction) = await self.__push_queue.get()
await self.__internal_push_frame(frame, direction)
self.__push_queue.task_done()
def _create_observer_proxies(
self, task_manager: TaskManager, observers: Sequence[BaseObserver]
) -> List[FrameObserverProxy]:
result = []
for observer in observers:
proxy = FrameObserverProxy(task_manager, observer)
proxy.start()
result.append(proxy)
return result
def __create_observers_task(self):
if not self.__observers_task:
self.__observers_queue = asyncio.Queue()
self.__observers_task = self.create_task(self.__observers_task_handler())
async def __cancel_observers_task(self):
if self.__observers_task:
for proxy in self._observer_proxies:
await proxy.stop()
await self.cancel_task(self.__observers_task)
self.__observers_task = None
async def __observe_frame(self, data: FramePushed):
await self.__observers_queue.put(data)
async def __observers_task_handler(self):
while True:
data = await self.__observers_queue.get()
# Proxy observation to all observers.
for observer in self._observer_proxies:
await observer.observe(data)
self.__observers_queue.task_done()

View File

@@ -254,7 +254,7 @@ class RTVIBotReady(BaseModel):
class RTVILLMFunctionCallMessageData(BaseModel):
function_name: str
tool_call_id: str
args: Mapping[str, Any]
arguments: Mapping[str, Any]
class RTVILLMFunctionCallMessage(BaseModel):
@@ -700,7 +700,7 @@ class RTVIProcessor(FrameProcessor):
fn = RTVILLMFunctionCallMessageData(
function_name=params.function_name,
tool_call_id=params.tool_call_id,
args=params.arguments,
arguments=params.arguments,
)
message = RTVILLMFunctionCallMessage(data=fn)
await self._push_transport_message(message, exclude_none=False)

View File

@@ -250,24 +250,14 @@ class AnthropicLLMService(LLMService):
if hasattr(event.message.usage, "output_tokens")
else 0
)
cache_creation_input_tokens += (
event.message.usage.cache_creation_input_tokens
if (
hasattr(event.message.usage, "cache_creation_input_tokens")
and event.message.usage.cache_creation_input_tokens is not None
if hasattr(event.message.usage, "cache_creation_input_tokens"):
cache_creation_input_tokens += (
event.message.usage.cache_creation_input_tokens
)
else 0
)
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
cache_read_input_tokens += (
event.message.usage.cache_read_input_tokens
if (
hasattr(event.message.usage, "cache_read_input_tokens")
and event.message.usage.cache_read_input_tokens is not None
)
else 0
)
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
if hasattr(event.message.usage, "cache_read_input_tokens"):
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
total_input_tokens = (
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
)

View File

@@ -8,8 +8,6 @@ import sys
from pipecat.services import DeprecatedModuleProxy
from .llm import *
from .stt import *
from .tts import *
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]")
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")

View File

@@ -1,785 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import base64
import copy
import io
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.frames.frames import (
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
UserImageRawFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
try:
import boto3
import httpx
from botocore.config import Config
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
raise Exception(f"Missing module: {e}")
@dataclass
class AWSBedrockContextAggregatorPair:
_user: "AWSBedrockUserContextAggregator"
_assistant: "AWSBedrockAssistantContextAggregator"
def user(self) -> "AWSBedrockUserContextAggregator":
return self._user
def assistant(self) -> "AWSBedrockAssistantContextAggregator":
return self._assistant
class AWSBedrockLLMContext(OpenAILLMContext):
def __init__(
self,
messages: Optional[List[dict]] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[dict] = None,
*,
system: Optional[str] = None,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self.system = system
@staticmethod
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
logger.debug(f"Upgrading to AWS Bedrock: {obj}")
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext):
obj.__class__ = AWSBedrockLLMContext
obj._restructure_from_openai_messages()
else:
obj._restructure_from_bedrock_messages()
return obj
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
self = cls(
messages=openai_context.messages,
tools=openai_context.tools,
tool_choice=openai_context.tool_choice,
)
self.set_llm_adapter(openai_context.get_llm_adapter())
self._restructure_from_openai_messages()
return self
@classmethod
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
self = cls(messages=messages)
self._restructure_from_openai_messages()
return self
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
return context
def set_messages(self, messages: List):
self._messages[:] = messages
self._restructure_from_openai_messages()
# convert a message in AWS Bedrock format into one or more messages in OpenAI format
def to_standard_messages(self, obj):
"""Convert AWS Bedrock message format to standard structured format.
Handles text content and function calls for both user and assistant messages.
Args:
obj: Message in AWS Bedrock format:
{
"role": "user/assistant",
"content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}]
}
Returns:
List of messages in standard format:
[
{
"role": "user/assistant/tool",
"content": [{"type": "text", "text": str}]
}
]
"""
role = obj.get("role")
content = obj.get("content")
if role == "assistant":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolUse" in item:
tool_use = item["toolUse"]
tool_items.append(
{
"type": "function",
"id": tool_use["toolUseId"],
"function": {
"name": tool_use["name"],
"arguments": json.dumps(tool_use["input"]),
},
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
if tool_items:
messages.append({"role": role, "tool_calls": tool_items})
return messages
elif role == "user":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolResult" in item:
tool_result = item["toolResult"]
# Extract content from toolResult
result_content = ""
if isinstance(tool_result["content"], list):
for content_item in tool_result["content"]:
if "text" in content_item:
result_content = content_item["text"]
elif "json" in content_item:
result_content = json.dumps(content_item["json"])
else:
result_content = tool_result["content"]
tool_items.append(
{
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": result_content,
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
messages.extend(tool_items)
return messages
def from_standard_message(self, message):
"""Convert standard format message to AWS Bedrock format.
Handles conversion of text content, tool calls, and tool results.
Empty text content is converted to "(empty)".
Args:
message: Message in standard format:
{
"role": "user/assistant/tool",
"content": str | [{"type": "text", ...}],
"tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}]
}
Returns:
Message in AWS Bedrock format:
{
"role": "user/assistant",
"content": [
{"text": str} |
{"toolUse": {"toolUseId": str, "name": str, "input": dict}} |
{"toolResult": {"toolUseId": str, "content": [...], "status": str}}
]
}
"""
if message["role"] == "tool":
# Try to parse the content as JSON if it looks like JSON
try:
if message["content"].strip().startswith("{") and message[
"content"
].strip().endswith("}"):
content_json = json.loads(message["content"])
tool_result_content = [{"json": content_json}]
else:
tool_result_content = [{"text": message["content"]}]
except:
tool_result_content = [{"text": message["content"]}]
return {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": message["tool_call_id"],
"content": tool_result_content,
},
},
],
}
if message.get("tool_calls"):
tc = message["tool_calls"]
ret = {"role": "assistant", "content": []}
for tool_call in tc:
function = tool_call["function"]
arguments = json.loads(function["arguments"])
new_tool_use = {
"toolUse": {
"toolUseId": tool_call["id"],
"name": function["name"],
"input": arguments,
}
}
ret["content"].append(new_tool_use)
return ret
# Handle text content
content = message.get("content")
if isinstance(content, str):
if content == "":
return {"role": message["role"], "content": [{"text": "(empty)"}]}
else:
return {"role": message["role"], "content": [{"text": content}]}
elif isinstance(content, list):
new_content = []
for item in content:
if item.get("type", "") == "text":
text_content = item["text"] if item["text"] != "" else "(empty)"
new_content.append({"text": text_content})
return {"role": message["role"], "content": new_content}
return message
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Image should be the first content block in the message
content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}]
if text:
content.append({"text": text})
self.add_message({"role": "user", "content": content})
def add_message(self, message):
try:
if self.messages:
# AWS Bedrock requires that roles alternate. If this message's
# role is the same as the last message, we should add this
# message's content to the last message.
if self.messages[-1]["role"] == message["role"]:
# if the last message has just a content string, convert it to a list
# in the proper format
if isinstance(self.messages[-1]["content"], str):
self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}]
# if this message has just a content string, convert it to a list
# in the proper format
if isinstance(message["content"], str):
message["content"] = [{"text": message["content"]}]
# append the content of this message to the last message
self.messages[-1]["content"].extend(message["content"])
else:
self.messages.append(message)
else:
self.messages.append(message)
except Exception as e:
logger.error(f"Error adding message: {e}")
def _restructure_from_bedrock_messages(self):
"""Restructure messages in AWS Bedrock format by handling system
messages, merging consecutive messages with the same role, and ensuring
proper content formatting.
"""
# Handle system message if present at the beginning
if self.messages and self.messages[0]["role"] == "system":
if len(self.messages) == 1:
self.messages[0]["role"] = "user"
else:
system_content = self.messages.pop(0)["content"]
if isinstance(system_content, str):
system_content = [{"text": system_content}]
if self.system:
if isinstance(self.system, str):
self.system = [{"text": self.system}]
self.system.extend(system_content)
else:
self.system = system_content
# Ensure content is properly formatted
for msg in self.messages:
if isinstance(msg["content"], str):
msg["content"] = [{"text": msg["content"]}]
elif not msg["content"]:
msg["content"] = [{"text": "(empty)"}]
elif isinstance(msg["content"], list):
for idx, item in enumerate(msg["content"]):
if isinstance(item, dict) and "text" in item and item["text"] == "":
item["text"] = "(empty)"
elif isinstance(item, str) and item == "":
msg["content"][idx] = {"text": "(empty)"}
# Merge consecutive messages with the same role
merged_messages = []
for msg in self.messages:
if merged_messages and merged_messages[-1]["role"] == msg["role"]:
merged_messages[-1]["content"].extend(msg["content"])
else:
merged_messages.append(msg)
self.messages.clear()
self.messages.extend(merged_messages)
def _restructure_from_openai_messages(self):
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
try:
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
except Exception as e:
logger.error(f"Error mapping messages: {e}")
# See if we should pull the system message out of our context.messages list. (For
# compatibility with Open AI messages format.)
if self.messages and self.messages[0]["role"] == "system":
self.system = self.messages[0]["content"]
self.messages.pop(0)
# Merge consecutive messages with the same role.
i = 0
while i < len(self.messages) - 1:
current_message = self.messages[i]
next_message = self.messages[i + 1]
if current_message["role"] == next_message["role"]:
# Convert content to list of dictionaries if it's a string
if isinstance(current_message["content"], str):
current_message["content"] = [
{"type": "text", "text": current_message["content"]}
]
if isinstance(next_message["content"], str):
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
# Concatenate the content
current_message["content"].extend(next_message["content"])
# Remove the next message from the list
self.messages.pop(i + 1)
else:
i += 1
# Avoid empty content in messages
for message in self.messages:
if isinstance(message["content"], str) and message["content"] == "":
message["content"] = "(empty)"
elif isinstance(message["content"], list) and len(message["content"]) == 0:
message["content"] = [{"type": "text", "text": "(empty)"}]
def get_messages_for_persistent_storage(self):
messages = super().get_messages_for_persistent_storage()
if self.system:
messages.insert(0, {"role": "system", "content": self.system})
return messages
def get_messages_for_logging(self) -> str:
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item.get("image"):
item["source"]["bytes"] = "..."
msgs.append(msg)
return json.dumps(msgs)
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
pass
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
# Format tool use according to AWS Bedrock API
self._context.add_message(
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": frame.tool_call_id,
"name": frame.function_name,
"input": frame.arguments if frame.arguments else {},
}
}
],
}
)
self._context.add_message(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": frame.tool_call_id,
"content": [{"text": "IN_PROGRESS"}],
}
}
],
}
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
else:
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "COMPLETED"
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: Any
):
for message in self._context.messages:
if message["role"] == "user":
for content in message["content"]:
if (
isinstance(content, dict)
and content.get("toolResult")
and content["toolResult"]["toolUseId"] == tool_call_id
):
content["toolResult"]["content"] = [{"text": result}]
async def handle_user_image_frame(self, frame: UserImageRawFrame):
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
text=frame.request.context,
)
class AWSBedrockLLMService(LLMService):
"""This class implements inference with AWS Bedrock models including Amazon
Nova and Anthropic Claude.
Requires AWS credentials to be configured in the environment or through
boto3 configuration.
"""
# Overriding the default adapter to use the Anthropic one.
adapter_class = AWSBedrockLLMAdapter
class InputParams(BaseModel):
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
stop_sequences: Optional[List[str]] = Field(default_factory=lambda: [])
latency: Optional[str] = Field(default_factory=lambda: "standard")
additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict)
def __init__(
self,
*,
aws_access_key: Optional[str] = None,
aws_secret_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region: str = "us-east-1",
model: str,
params: InputParams = InputParams(),
client_config: Optional[Config] = None,
**kwargs,
):
super().__init__(**kwargs)
# Initialize the AWS Bedrock client
if not client_config:
client_config = Config(
connect_timeout=300, # 5 minutes
read_timeout=300, # 5 minutes
retries={"max_attempts": 3},
)
session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
region_name=aws_region,
)
self._client = session.client(service_name="bedrock-runtime", config=client_config)
self.set_model_name(model)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"top_p": params.top_p,
"latency": params.latency,
"additional_model_request_fields": params.additional_model_request_fields
if isinstance(params.additional_model_request_fields, dict)
else {},
}
logger.info(f"Using AWS Bedrock model: {model}")
def can_generate_metrics(self) -> bool:
return True
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSBedrockContextAggregatorPair:
"""Create an instance of AWSBedrockContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
Returns:
AWSBedrockContextAggregatorPair: A pair of context aggregators, one
for the user and one for the assistant, encapsulated in an
AWSBedrockContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
if isinstance(context, OpenAILLMContext):
context = AWSBedrockLLMContext.from_openai_context(context)
user = AWSBedrockUserContextAggregator(context, params=user_params)
assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params)
return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant)
async def _process_context(self, context: AWSBedrockLLMContext):
# Usage tracking
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
cache_read_input_tokens = 0
cache_creation_input_tokens = 0
use_completion_tokens_estimate = False
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Set up inference config
inference_config = {
"maxTokens": self._settings["max_tokens"],
"temperature": self._settings["temperature"],
"topP": self._settings["top_p"],
}
# Prepare request parameters
request_params = {
"modelId": self.model_name,
"messages": context.messages,
"inferenceConfig": inference_config,
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
}
# Add system message
request_params["system"] = context.system
# Add tools if present
if context.tools:
tool_config = {"tools": context.tools}
# Add tool_choice if specified
if context.tool_choice:
if context.tool_choice == "auto":
tool_config["toolChoice"] = {"auto": {}}
elif context.tool_choice == "none":
# Skip adding toolChoice for "none"
pass
elif (
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
):
tool_config["toolChoice"] = {
"tool": {"name": context.tool_choice["function"]["name"]}
}
request_params["toolConfig"] = tool_config
# Add performance config if latency is specified
if self._settings["latency"] in ["standard", "optimized"]:
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
logger.debug(f"Calling AWS Bedrock model with: {request_params}")
# Call AWS Bedrock with streaming
response = self._client.converse_stream(**request_params)
await self.stop_ttfb_metrics()
# Process the streaming response
tool_use_block = None
json_accumulator = ""
for event in response["stream"]:
# Handle text content
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
await self.push_frame(LLMTextFrame(delta["text"]))
completion_tokens_estimate += self._estimate_tokens(delta["text"])
elif "toolUse" in delta and "input" in delta["toolUse"]:
# Handle partial JSON for tool use
json_accumulator += delta["toolUse"]["input"]
completion_tokens_estimate += self._estimate_tokens(
delta["toolUse"]["input"]
)
# Handle tool use start
elif "contentBlockStart" in event:
content_block_start = event["contentBlockStart"]["start"]
if "toolUse" in content_block_start:
tool_use_block = {
"id": content_block_start["toolUse"].get("toolUseId", ""),
"name": content_block_start["toolUse"].get("name", ""),
}
json_accumulator = ""
# Handle message completion with tool use
elif "messageStop" in event and "stopReason" in event["messageStop"]:
if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
try:
arguments = json.loads(json_accumulator) if json_accumulator else {}
await self.call_function(
context=context,
tool_call_id=tool_use_block["id"],
function_name=tool_use_block["name"],
arguments=arguments,
)
except json.JSONDecodeError:
logger.error(f"Failed to parse tool arguments: {json_accumulator}")
# Handle usage metrics if available
if "metadata" in event and "usage" in event["metadata"]:
usage = event["metadata"]["usage"]
prompt_tokens += usage.get("inputTokens", 0)
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. The reraise the exception so all the processors running in this task
# also get cancelled.
use_completion_tokens_estimate = True
raise
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
comp_tokens = (
completion_tokens
if not use_completion_tokens_estimate
else completion_tokens_estimate
)
await self._report_usage_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=comp_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
elif isinstance(frame, LLMMessagesFrame):
context = AWSBedrockLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
# This is only useful in very simple pipelines because it creates
# a new context. Generally we want a context manager to catch
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = AWSBedrockLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
if context:
await self._process_context(context)
def _estimate_tokens(self, text: str) -> int:
return int(len(re.split(r"[^\w]+", text)) * 1.3)
async def _report_usage_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int,
):
if prompt_tokens or completion_tokens:
tokens = LLMTokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
)
await self.start_llm_usage_metrics(tokens)

View File

@@ -1,329 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import json
import os
import random
import string
from typing import AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
raise Exception(f"Missing module: {e}")
class AWSTranscribeSTTService(STTService):
def __init__(
self,
*,
api_key: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_session_token: Optional[str] = None,
region: Optional[str] = "us-east-1",
sample_rate: int = 16000,
language: Language = Language.EN,
**kwargs,
):
super().__init__(**kwargs)
self._settings = {
"sample_rate": sample_rate,
"language": language,
"media_encoding": "linear16", # AWS expects raw PCM
"number_of_channels": 1,
"show_speaker_label": False,
"enable_channel_identification": False,
}
# Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz
if sample_rate not in [8000, 16000]:
logger.warning(
f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
)
self._settings["sample_rate"] = 16000
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
self._ws_client = None
self._connection_lock = asyncio.Lock()
self._connecting = False
self._receive_task = None
def get_service_encoding(self, encoding: str) -> str:
"""Convert internal encoding format to AWS Transcribe format."""
encoding_map = {
"linear16": "pcm", # AWS expects "pcm" for 16-bit linear PCM
}
return encoding_map.get(encoding, encoding)
async def start(self, frame: StartFrame):
"""Initialize the connection when the service starts."""
await super().start(frame)
logger.info("Starting AWS Transcribe service...")
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
await self._connect()
if self._ws_client and self._ws_client.open:
logger.info("Successfully established WebSocket connection")
return
logger.warning("WebSocket connection not established after connect")
except Exception as e:
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
retry_count += 1
if retry_count < max_retries:
await asyncio.sleep(1) # Wait before retrying
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data and send to AWS Transcribe"""
try:
# Ensure WebSocket is connected
if not self._ws_client or not self._ws_client.open:
logger.debug("WebSocket not connected, attempting to reconnect...")
try:
await self._connect()
except Exception as e:
logger.error(f"Failed to reconnect: {e}")
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
return
# Format the audio data according to AWS event stream format
event_message = build_event_message(audio)
# Send the formatted event message
try:
await self._ws_client.send(event_message)
# Start metrics after first chunk sent
await self.start_processing_metrics()
await self.start_ttfb_metrics()
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Connection closed while sending: {e}")
await self._disconnect()
# Don't yield error here - we'll retry on next frame
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
except Exception as e:
logger.error(f"Error in run_stt: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
async def _connect(self):
"""Connect to AWS Transcribe with connection state management."""
if self._ws_client and self._ws_client.open and self._receive_task:
logger.debug(f"{self} Already connected")
return
async with self._connection_lock:
if self._connecting:
logger.debug(f"{self} Connection already in progress")
return
try:
self._connecting = True
logger.debug(f"{self} Starting connection process...")
if self._ws_client:
await self._disconnect()
language_code = self.language_to_service_language(
Language(self._settings["language"])
)
if not language_code:
raise ValueError(f"Unsupported language: {self._settings['language']}")
# Generate random websocket key
websocket_key = "".join(
random.choices(
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
)
)
# Add required headers
extra_headers = {
"Origin": "https://localhost",
"Sec-WebSocket-Key": websocket_key,
"Sec-WebSocket-Version": "13",
"Connection": "keep-alive",
}
# Get presigned URL
presigned_url = get_presigned_url(
region=self._credentials["region"],
credentials={
"access_key": self._credentials["aws_access_key_id"],
"secret_key": self._credentials["aws_secret_access_key"],
"session_token": self._credentials["aws_session_token"],
},
language_code=language_code,
media_encoding=self.get_service_encoding(
self._settings["media_encoding"]
), # Convert to AWS format
sample_rate=self._settings["sample_rate"],
number_of_channels=self._settings["number_of_channels"],
enable_partial_results_stabilization=True,
partial_results_stability="high",
show_speaker_label=self._settings["show_speaker_label"],
enable_channel_identification=self._settings["enable_channel_identification"],
)
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
# Connect with the required headers and settings
self._ws_client = await websockets.connect(
presigned_url,
extra_headers=extra_headers,
subprotocols=["mqtt"],
ping_interval=None,
ping_timeout=None,
compression=None,
)
logger.debug(f"{self} WebSocket connected, starting receive task...")
# Start receive task
self._receive_task = self.create_task(self._receive_loop())
logger.info(f"{self} Successfully connected to AWS Transcribe")
except Exception as e:
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
await self._disconnect()
raise
finally:
self._connecting = False
async def _disconnect(self):
"""Disconnect from AWS Transcribe."""
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
try:
if self._ws_client and self._ws_client.open:
# Send end-stream message
end_stream = {"message-type": "event", "event": "end"}
await self._ws_client.send(json.dumps(end_stream))
await self._ws_client.close()
except Exception as e:
logger.warning(f"{self} Error closing WebSocket connection: {e}")
finally:
self._ws_client = None
def language_to_service_language(self, language: Language) -> str | None:
"""Convert internal language enum to AWS Transcribe language code."""
language_map = {
Language.EN: "en-US",
Language.ES: "es-US",
Language.FR: "fr-FR",
Language.DE: "de-DE",
Language.IT: "it-IT",
Language.PT: "pt-BR",
Language.JA: "ja-JP",
Language.KO: "ko-KR",
Language.ZH: "zh-CN",
}
return language_map.get(language)
async def _receive_loop(self):
"""Background task to receive and process messages from AWS Transcribe."""
while True:
if not self._ws_client or not self._ws_client.open:
logger.warning(f"{self} WebSocket closed in receive loop")
break
try:
response = await self._ws_client.recv()
headers, payload = decode_event(response)
if headers.get(":message-type") == "event":
# Process transcription results
results = payload.get("Transcript", {}).get("Results", [])
if results:
result = results[0]
alternatives = result.get("Alternatives", [])
if alternatives:
transcript = alternatives[0].get("Transcript", "")
is_final = not result.get("IsPartial", True)
if transcript:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
await self.stop_processing_metrics()
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
elif headers.get(":message-type") == "exception":
error_msg = payload.get("Message", "Unknown error")
logger.error(f"{self} Exception from AWS: {error_msg}")
await self.push_frame(
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
)
else:
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")
break

View File

@@ -5,7 +5,6 @@
#
import asyncio
import os
from typing import AsyncGenerator, Optional
from loguru import logger
@@ -27,7 +26,9 @@ try:
from botocore.exceptions import BotoCoreError, ClientError
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
logger.error(
"In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
raise Exception(f"Missing module: {e}")
@@ -107,7 +108,7 @@ def language_to_aws_language(language: Language) -> Optional[str]:
return language_map.get(language)
class AWSPollyTTSService(TTSService):
class PollyTTSService(TTSService):
class InputParams(BaseModel):
engine: Optional[str] = None
language: Optional[Language] = Language.EN
@@ -150,24 +151,6 @@ class AWSPollyTTSService(TTSService):
self.set_voice(voice_id)
# Get credentials from environment variables if not provided
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
# Validate that we have the required credentials
if (
not self._credentials["aws_access_key_id"]
or not self._credentials["aws_secret_access_key"]
):
raise ValueError(
"AWS credentials not found. Please provide them either through constructor parameters "
"or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
)
def can_generate_metrics(self) -> bool:
return True
@@ -182,17 +165,18 @@ class AWSPollyTTSService(TTSService):
prosody_attrs = []
# Prosody tags are only supported for standard and neural engines
if self._settings["engine"] == "standard":
if self._settings["engine"] != "generative":
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["pitch"]:
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
if prosody_attrs:
ssml += f"<prosody {' '.join(prosody_attrs)}>"
if prosody_attrs:
ssml += f"<prosody {' '.join(prosody_attrs)}>"
else:
logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
ssml += text
@@ -203,8 +187,6 @@ class AWSPollyTTSService(TTSService):
ssml += "</speak>"
logger.trace(f"{self} SSML: {ssml}")
return ssml
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
@@ -266,17 +248,3 @@ class AWSPollyTTSService(TTSService):
finally:
yield TTSStoppedFrame()
class PollyTTSService(AWSPollyTTSService):
def __init__(self, **kwargs):
super().__init__(**kwargs)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
DeprecationWarning,
)

View File

@@ -1,261 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import binascii
import datetime
import hashlib
import hmac
import json
import struct
import urllib.parse
from typing import Dict, Optional
def get_presigned_url(
*,
region: str,
credentials: Dict[str, Optional[str]],
language_code: str,
media_encoding: str = "pcm",
sample_rate: int = 16000,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = True,
partial_results_stability: str = "high",
vocabulary_name: Optional[str] = None,
vocabulary_filter_name: Optional[str] = None,
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
) -> str:
"""Create a presigned URL for AWS Transcribe streaming."""
access_key = credentials.get("access_key")
secret_key = credentials.get("secret_key")
session_token = credentials.get("session_token")
if not access_key or not secret_key:
raise ValueError("AWS credentials are required")
# Initialize the URL generator
url_generator = AWSTranscribePresignedURL(
access_key=access_key, secret_key=secret_key, session_token=session_token, region=region
)
# Get the presigned URL
return url_generator.get_request_url(
sample_rate=sample_rate,
language_code=language_code,
media_encoding=media_encoding,
vocabulary_name=vocabulary_name,
vocabulary_filter_name=vocabulary_filter_name,
show_speaker_label=show_speaker_label,
enable_channel_identification=enable_channel_identification,
number_of_channels=number_of_channels,
enable_partial_results_stabilization=enable_partial_results_stabilization,
partial_results_stability=partial_results_stability,
)
class AWSTranscribePresignedURL:
def __init__(
self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1"
):
self.access_key = access_key
self.secret_key = secret_key
self.session_token = session_token
self.method = "GET"
self.service = "transcribe"
self.region = region
self.endpoint = ""
self.host = ""
self.amz_date = ""
self.datestamp = ""
self.canonical_uri = "/stream-transcription-websocket"
self.canonical_headers = ""
self.signed_headers = "host"
self.algorithm = "AWS4-HMAC-SHA256"
self.credential_scope = ""
self.canonical_querystring = ""
self.payload_hash = ""
self.canonical_request = ""
self.string_to_sign = ""
self.signature = ""
self.request_url = ""
def get_request_url(
self,
sample_rate: int,
language_code: str = "",
media_encoding: str = "pcm",
vocabulary_name: str = "",
vocabulary_filter_name: str = "",
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = False,
partial_results_stability: str = "",
) -> str:
self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"
now = datetime.datetime.utcnow()
self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
self.datestamp = now.strftime("%Y%m%d")
self.canonical_headers = f"host:{self.host}\n"
self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"
# Create canonical querystring
self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm
self.canonical_querystring += (
"&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope
)
self.canonical_querystring += "&X-Amz-Date=" + self.amz_date
self.canonical_querystring += "&X-Amz-Expires=300"
if self.session_token:
self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote(
self.session_token, safe=""
)
self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers
if enable_channel_identification:
self.canonical_querystring += "&enable-channel-identification=true"
if enable_partial_results_stabilization:
self.canonical_querystring += "&enable-partial-results-stabilization=true"
if language_code:
self.canonical_querystring += "&language-code=" + language_code
if media_encoding:
self.canonical_querystring += "&media-encoding=" + media_encoding
if number_of_channels > 1:
self.canonical_querystring += "&number-of-channels=" + str(number_of_channels)
if partial_results_stability:
self.canonical_querystring += "&partial-results-stability=" + partial_results_stability
if sample_rate:
self.canonical_querystring += "&sample-rate=" + str(sample_rate)
if show_speaker_label:
self.canonical_querystring += "&show-speaker-label=true"
if vocabulary_filter_name:
self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name
if vocabulary_name:
self.canonical_querystring += "&vocabulary-name=" + vocabulary_name
# Create payload hash
self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()
# Create canonical request
self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}"
# Create string to sign
credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"
string_to_sign = (
f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n"
+ hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
)
# Calculate signature
k_date = hmac.new(
f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256
).digest()
k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
self.signature = hmac.new(
k_signing, string_to_sign.encode("utf-8"), hashlib.sha256
).hexdigest()
# Add signature to query string
self.canonical_querystring += "&X-Amz-Signature=" + self.signature
# Create request URL
self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
return self.request_url
def get_headers(header_name: str, header_value: str) -> bytearray:
"""Build a header following AWS event stream format."""
name = header_name.encode("utf-8")
name_byte_length = bytes([len(name)])
value_type = bytes([7]) # 7 represents a string
value = header_value.encode("utf-8")
value_byte_length = struct.pack(">H", len(value))
# Construct the header
header_list = bytearray()
header_list.extend(name_byte_length)
header_list.extend(name)
header_list.extend(value_type)
header_list.extend(value_byte_length)
header_list.extend(value)
return header_list
def build_event_message(payload: bytes) -> bytes:
"""
Build an event message for AWS Transcribe streaming.
Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py
"""
# Build headers
content_type_header = get_headers(":content-type", "application/octet-stream")
event_type_header = get_headers(":event-type", "AudioEvent")
message_type_header = get_headers(":message-type", "event")
headers = bytearray()
headers.extend(content_type_header)
headers.extend(event_type_header)
headers.extend(message_type_header)
# Calculate total byte length and headers byte length
# 16 accounts for 8 byte prelude, 2x 4 byte CRCs
total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16)
headers_byte_length = struct.pack(">I", len(headers))
# Build the prelude
prelude = bytearray([0] * 8)
prelude[:4] = total_byte_length
prelude[4:] = headers_byte_length
# Calculate checksum for prelude
prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
# Construct the message
message_as_list = bytearray()
message_as_list.extend(prelude)
message_as_list.extend(prelude_crc)
message_as_list.extend(headers)
message_as_list.extend(payload)
# Calculate checksum for message
message = bytes(message_as_list)
message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF)
# Add message checksum
message_as_list.extend(message_crc)
return bytes(message_as_list)
def decode_event(message):
# Extract the prelude, headers, payload and CRC
prelude = message[:8]
total_length, headers_length = struct.unpack(">II", prelude)
prelude_crc = struct.unpack(">I", message[8:12])[0]
headers = message[12 : 12 + headers_length]
payload = message[12 + headers_length : -4]
message_crc = struct.unpack(">I", message[-4:])[0]
# Check the CRCs
assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed"
assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed"
# Parse the headers
headers_dict = {}
while headers:
name_len = headers[0]
name = headers[1 : 1 + name_len].decode("utf-8")
value_type = headers[1 + name_len]
value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0]
value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8")
headers_dict[name] = value
headers = headers[4 + name_len + value_len :]
return headers_dict, json.loads(payload)

View File

@@ -1 +0,0 @@
from .aws import AWSNovaSonicLLMService

View File

@@ -1,997 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import base64
import json
import time
import uuid
import wave
from dataclasses import dataclass
from enum import Enum
from importlib.resources import files
from typing import Any, List, Optional
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws_nova_sonic.context import (
AWSNovaSonicAssistantContextAggregator,
AWSNovaSonicContextAggregatorPair,
AWSNovaSonicLLMContext,
AWSNovaSonicUserContextAggregator,
Role,
)
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
from pipecat.services.llm_service import LLMService
from pipecat.utils.time import time_now_iso8601
try:
from aws_sdk_bedrock_runtime.client import (
BedrockRuntimeClient,
InvokeModelWithBidirectionalStreamOperationInput,
)
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
from aws_sdk_bedrock_runtime.models import (
BidirectionalInputPayloadPart,
InvokeModelWithBidirectionalStreamInput,
InvokeModelWithBidirectionalStreamInputChunk,
InvokeModelWithBidirectionalStreamOperationOutput,
InvokeModelWithBidirectionalStreamOutput,
)
from smithy_aws_core.credentials_resolvers.static import StaticCredentialsResolver
from smithy_aws_core.identity import AWSCredentialsIdentity
from smithy_core.aio.eventstream import DuplexEventStream
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use AWS services, you need to `pip install pipecat-ai[aws-nova-sonic]`."
)
raise Exception(f"Missing module: {e}")
class AWSNovaSonicUnhandledFunctionException(Exception):
pass
class ContentType(Enum):
AUDIO = "AUDIO"
TEXT = "TEXT"
TOOL = "TOOL"
class TextStage(Enum):
FINAL = "FINAL" # what has been said
SPECULATIVE = "SPECULATIVE" # what's planned to be said
@dataclass
class CurrentContent:
type: ContentType
role: Role
text_stage: TextStage # None if not text
text_content: str # starts as None, then fills in if text
def __str__(self):
return (
f"CurrentContent(\n"
f" type={self.type.name},\n"
f" role={self.role.name},\n"
f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n"
f")"
)
class Params(BaseModel):
# Audio input
input_sample_rate: Optional[int] = Field(default=16000)
input_sample_size: Optional[int] = Field(default=16)
input_channel_count: Optional[int] = Field(default=1)
# Audio output
output_sample_rate: Optional[int] = Field(default=24000)
output_sample_size: Optional[int] = Field(default=16)
output_channel_count: Optional[int] = Field(default=1)
# Inference
max_tokens: Optional[int] = Field(default=1024)
top_p: Optional[float] = Field(default=0.9)
temperature: Optional[float] = Field(default=0.7)
class AWSNovaSonicLLMService(LLMService):
# Override the default adapter to use the AWSNovaSonicLLMAdapter one
adapter_class = AWSNovaSonicLLMAdapter
def __init__(
self,
*,
secret_access_key: str,
access_key_id: str,
region: str,
model: str = "amazon.nova-sonic-v1:0",
voice_id: str = "matthew", # matthew, tiffany, amy
params: Params = Params(),
system_instruction: Optional[str] = None,
tools: Optional[ToolsSchema] = None,
send_transcription_frames: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._secret_access_key = secret_access_key
self._access_key_id = access_key_id
self._region = region
self._model = model
self._client: Optional[BedrockRuntimeClient] = None
self._voice_id = voice_id
self._params = params
self._system_instruction = system_instruction
self._tools = tools
self._send_transcription_frames = send_transcription_frames
self._context: Optional[AWSNovaSonicLLMContext] = None
self._stream: Optional[
DuplexEventStream[
InvokeModelWithBidirectionalStreamInput,
InvokeModelWithBidirectionalStreamOutput,
InvokeModelWithBidirectionalStreamOperationOutput,
]
] = None
self._receive_task: Optional[asyncio.Task] = None
self._prompt_name: Optional[str] = None
self._input_audio_content_name: Optional[str] = None
self._content_being_received: Optional[CurrentContent] = None
self._assistant_is_responding = False
self._ready_to_send_context = False
self._handling_bot_stopped_speaking = False
self._triggering_assistant_response = False
self._assistant_response_trigger_audio: Optional[bytes] = (
None # Not cleared on _disconnect()
)
self._disconnecting = False
self._connected_time: Optional[float] = None
self._wants_connection = False
#
# standard AIService frame handling
#
async def start(self, frame: StartFrame):
await super().start(frame)
self._wants_connection = True
await self._start_connecting()
async def stop(self, frame: EndFrame):
await super().stop(frame)
self._wants_connection = False
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
self._wants_connection = False
await self._disconnect()
#
# conversation resetting
#
async def reset_conversation(self):
logger.debug("Resetting conversation")
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False)
# Carry over previous context through disconnect
context = self._context
await self._disconnect()
self._context = context
await self._start_connecting()
#
# frame processing
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
await self._handle_context(frame.context)
elif isinstance(frame, InputAudioRawFrame):
await self._handle_input_audio_frame(frame)
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=True)
elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame):
await self._handle_function_call_result(frame)
await self.push_frame(frame, direction)
async def _handle_context(self, context: OpenAILLMContext):
if not self._context:
# We got our initial context - try to finish connecting
self._context = AWSNovaSonicLLMContext.upgrade_to_nova_sonic(
context, self._system_instruction
)
await self._finish_connecting_if_context_available()
async def _handle_input_audio_frame(self, frame: InputAudioRawFrame):
# Wait until we're done sending the assistant response trigger audio before sending audio
# from the user's mic
if self._triggering_assistant_response:
return
await self._send_user_audio_event(frame.audio)
async def _handle_bot_stopped_speaking(self, delay_to_catch_trailing_assistant_text: bool):
# Protect against back-to-back BotStoppedSpeaking calls, which I've observed
if self._handling_bot_stopped_speaking:
return
self._handling_bot_stopped_speaking = True
async def finalize_assistant_response():
if self._assistant_is_responding:
# Consider the assistant finished with their response (possibly after a short delay,
# to allow for any trailing FINAL assistant text block to come in that need to make
# it into context).
#
# TODO: ideally we could base this solely on the LLM output events, but I couldn't
# figure out a reliable way to determine when we've gotten our last FINAL text block
# after the LLM is done talking.
#
# First I looked at stopReason, but it doesn't seem like the last FINAL text block
# is reliably marked END_TURN (sometimes the *first* one is, but not the last...
# bug?)
#
# Then I considered schemes where we tally or match up SPECULATIVE text blocks with
# FINAL text blocks to know how many or which FINAL blocks to expect, but user
# interruptions throw a wrench in these schemes: depending on the exact timing of
# the interruption, we should or shouldn't expect some FINAL blocks.
if delay_to_catch_trailing_assistant_text:
# This delay length is a balancing act between "catching" trailing assistant
# text that is quite delayed but not waiting so long that user text comes in
# first and results in a bit of context message order scrambling.
await asyncio.sleep(1.25)
self._assistant_is_responding = False
await self._report_assistant_response_ended()
self._handling_bot_stopped_speaking = False
# Finalize the assistant response, either now or after a delay
if delay_to_catch_trailing_assistant_text:
self.create_task(finalize_assistant_response())
else:
await finalize_assistant_response()
async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame):
result = frame.result_frame
await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result)
#
# LLM communication: lifecycle
#
async def _start_connecting(self):
try:
logger.info("Connecting...")
if self._client:
# Here we assume that if we have a client we are connected or connecting
return
# Set IDs for the connection
self._prompt_name = str(uuid.uuid4())
self._input_audio_content_name = str(uuid.uuid4())
# Create the client
self._client = self._create_client()
# Start the bidirectional stream
self._stream = await self._client.invoke_model_with_bidirectional_stream(
InvokeModelWithBidirectionalStreamOperationInput(model_id=self._model)
)
# Send session start event
await self._send_session_start_event()
# Finish connecting
self._ready_to_send_context = True
await self._finish_connecting_if_context_available()
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._disconnect()
async def _finish_connecting_if_context_available(self):
# We can only finish connecting once we've gotten our initial context and we're ready to
# send it
if not (self._context and self._ready_to_send_context):
return
logger.info("Finishing connecting (setting up session)...")
# Read context
history = self._context.get_messages_for_initializing_history()
# Send prompt start event, specifying tools.
# Tools from context take priority over self._tools.
tools = (
self._context.tools
if self._context.tools
else self.get_llm_adapter().from_standard_tools(self._tools)
)
logger.debug(f"Using tools: {tools}")
await self._send_prompt_start_event(tools)
# Send system instruction.
# Instruction from context takes priority over self._system_instruction.
# (NOTE: this prioritizing occurred automatically behind the scenes: the context was
# initialized with self._system_instruction and then updated itself from its messages when
# get_messages_for_initializing_history() was called).
logger.debug(f"Using system instruction: {history.system_instruction}")
if history.system_instruction:
await self._send_text_event(text=history.system_instruction, role=Role.SYSTEM)
# Send conversation history
for message in history.messages:
await self._send_text_event(text=message.text, role=message.role)
# Start audio input
await self._send_audio_input_start_event()
# Start receiving events
self._receive_task = self.create_task(self._receive_task_handler())
# Record finished connecting time (must be done before sending assistant response trigger)
self._connected_time = time.time()
logger.info("Finished connecting")
# If we need to, send assistant response trigger (depends on self._connected_time)
if self._triggering_assistant_response:
await self._send_assistant_response_trigger()
self._triggering_assistant_response = False
async def _disconnect(self):
try:
logger.info("Disconnecting...")
# NOTE: see explanation of HACK, below
self._disconnecting = True
# Clean up client
if self._client:
await self._send_session_end_events()
self._client = None
# Clean up stream
if self._stream:
await self._stream.input_stream.close()
self._stream = None
# NOTE: see explanation of HACK, below
await asyncio.sleep(1)
# Clean up receive task
# HACK: we should ideally be able to cancel the receive task before stopping the input
# stream, above (meaning we wouldn't need self._disconnecting). But for some reason if
# we don't close the input stream and wait a second first, we're getting an error a lot
# like this one: https://github.com/awslabs/amazon-transcribe-streaming-sdk/issues/61.
if self._receive_task:
await self.cancel_task(self._receive_task, timeout=1.0)
self._receive_task = None
# Reset remaining connection-specific state
self._prompt_name = None
self._input_audio_content_name = None
self._content_being_received = None
self._assistant_is_responding = False
self._ready_to_send_context = False
self._handling_bot_stopped_speaking = False
self._triggering_assistant_response = False
self._disconnecting = False
self._connected_time = None
logger.info("Finished disconnecting")
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
def _create_client(self) -> BedrockRuntimeClient:
config = Config(
endpoint_uri=f"https://bedrock-runtime.{self._region}.amazonaws.com",
region=self._region,
aws_credentials_identity_resolver=StaticCredentialsResolver(
credentials=AWSCredentialsIdentity(
access_key_id=self._access_key_id, secret_access_key=self._secret_access_key
)
),
http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
)
return BedrockRuntimeClient(config=config)
#
# LLM communication: input events (pipecat -> LLM)
#
async def _send_session_start_event(self):
session_start = f"""
{{
"event": {{
"sessionStart": {{
"inferenceConfiguration": {{
"maxTokens": {self._params.max_tokens},
"topP": {self._params.top_p},
"temperature": {self._params.temperature}
}}
}}
}}
}}
"""
await self._send_client_event(session_start)
async def _send_prompt_start_event(self, tools: List[Any]):
if not self._prompt_name:
return
tools_config = (
f""",
"toolUseOutputConfiguration": {{
"mediaType": "application/json"
}},
"toolConfiguration": {{
"tools": {json.dumps(tools)}
}}
"""
if tools
else ""
)
prompt_start = f'''
{{
"event": {{
"promptStart": {{
"promptName": "{self._prompt_name}",
"textOutputConfiguration": {{
"mediaType": "text/plain"
}},
"audioOutputConfiguration": {{
"mediaType": "audio/lpcm",
"sampleRateHertz": {self._params.output_sample_rate},
"sampleSizeBits": {self._params.output_sample_size},
"channelCount": {self._params.output_channel_count},
"voiceId": "{self._voice_id}",
"encoding": "base64",
"audioType": "SPEECH"
}}{tools_config}
}}
}}
}}
'''
await self._send_client_event(prompt_start)
async def _send_audio_input_start_event(self):
if not self._prompt_name:
return
audio_content_start = f'''
{{
"event": {{
"contentStart": {{
"promptName": "{self._prompt_name}",
"contentName": "{self._input_audio_content_name}",
"type": "AUDIO",
"interactive": true,
"role": "USER",
"audioInputConfiguration": {{
"mediaType": "audio/lpcm",
"sampleRateHertz": {self._params.input_sample_rate},
"sampleSizeBits": {self._params.input_sample_size},
"channelCount": {self._params.input_channel_count},
"audioType": "SPEECH",
"encoding": "base64"
}}
}}
}}
}}
'''
await self._send_client_event(audio_content_start)
async def _send_text_event(self, text: str, role: Role):
if not self._stream or not self._prompt_name or not text:
return
content_name = str(uuid.uuid4())
text_content_start = f'''
{{
"event": {{
"contentStart": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}",
"type": "TEXT",
"interactive": true,
"role": "{role.value}",
"textInputConfiguration": {{
"mediaType": "text/plain"
}}
}}
}}
}}
'''
await self._send_client_event(text_content_start)
escaped_text = json.dumps(text) # includes quotes
text_input = f'''
{{
"event": {{
"textInput": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}",
"content": {escaped_text}
}}
}}
}}
'''
await self._send_client_event(text_input)
text_content_end = f'''
{{
"event": {{
"contentEnd": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}"
}}
}}
}}
'''
await self._send_client_event(text_content_end)
async def _send_user_audio_event(self, audio: bytes):
if not self._stream:
return
blob = base64.b64encode(audio)
audio_event = f'''
{{
"event": {{
"audioInput": {{
"promptName": "{self._prompt_name}",
"contentName": "{self._input_audio_content_name}",
"content": "{blob.decode("utf-8")}"
}}
}}
}}
'''
await self._send_client_event(audio_event)
async def _send_session_end_events(self):
if not self._stream or not self._prompt_name:
return
prompt_end = f'''
{{
"event": {{
"promptEnd": {{
"promptName": "{self._prompt_name}"
}}
}}
}}
'''
await self._send_client_event(prompt_end)
session_end = """
{
"event": {
"sessionEnd": {}
}
}
"""
await self._send_client_event(session_end)
async def _send_tool_result(self, tool_call_id, result):
if not self._stream or not self._prompt_name:
return
content_name = str(uuid.uuid4())
result_content_start = f'''
{{
"event": {{
"contentStart": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}",
"interactive": false,
"type": "TOOL",
"role": "TOOL",
"toolResultInputConfiguration": {{
"toolUseId": "{tool_call_id}",
"type": "TEXT",
"textInputConfiguration": {{
"mediaType": "text/plain"
}}
}}
}}
}}
}}
'''
await self._send_client_event(result_content_start)
result_content = json.dumps(
{
"event": {
"toolResult": {
"promptName": self._prompt_name,
"contentName": content_name,
"content": json.dumps(result) if isinstance(result, dict) else result,
}
}
}
)
await self._send_client_event(result_content)
result_content_end = f"""
{{
"event": {{
"contentEnd": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}"
}}
}}
}}
"""
await self._send_client_event(result_content_end)
async def _send_client_event(self, event_json: str):
if not self._stream: # should never happen
return
event = InvokeModelWithBidirectionalStreamInputChunk(
value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
)
await self._stream.input_stream.send(event)
#
# LLM communication: output events (LLM -> pipecat)
#
# Receive events for the session.
# A few different kinds of content can be delivered:
# - Transcription of user audio
# - Tool use
# - Text preview of planned response speech before audio delivered
# - User interruption notification
# - Text of response speech that whose audio was actually delivered
# - Audio of response speech
# Each piece of content is wrapped by "contentStart" and "contentEnd" events. The content is
# delivered sequentially: one piece of content will end before another starts.
# The overall completion is wrapped by "completionStart" and "completionEnd" events.
async def _receive_task_handler(self):
try:
while self._stream and not self._disconnecting:
output = await self._stream.await_output()
result = await output[1].receive()
if result.value and result.value.bytes_:
response_data = result.value.bytes_.decode("utf-8")
json_data = json.loads(response_data)
if "event" in json_data:
event_json = json_data["event"]
if "completionStart" in event_json:
# Handle the LLM completion starting
await self._handle_completion_start_event(event_json)
elif "contentStart" in event_json:
# Handle a piece of content starting
await self._handle_content_start_event(event_json)
elif "textOutput" in event_json:
# Handle text output content
await self._handle_text_output_event(event_json)
elif "audioOutput" in event_json:
# Handle audio output content
await self._handle_audio_output_event(event_json)
elif "toolUse" in event_json:
# Handle tool use
await self._handle_tool_use_event(event_json)
elif "contentEnd" in event_json:
# Handle a piece of content ending
await self._handle_content_end_event(event_json)
elif "completionEnd" in event_json:
# Handle the LLM completion ending
await self._handle_completion_end_event(event_json)
except Exception as e:
logger.error(f"{self} error processing responses: {e}")
if self._wants_connection:
await self.reset_conversation()
async def _handle_completion_start_event(self, event_json):
pass
async def _handle_content_start_event(self, event_json):
content_start = event_json["contentStart"]
type = content_start["type"]
role = content_start["role"]
generation_stage = None
if "additionalModelFields" in content_start:
additional_model_fields = json.loads(content_start["additionalModelFields"])
generation_stage = additional_model_fields.get("generationStage")
# Bookkeeping: track current content being received
content = CurrentContent(
type=ContentType(type),
role=Role(role),
text_stage=TextStage(generation_stage) if generation_stage else None,
text_content=None,
)
self._content_being_received = content
if content.role == Role.ASSISTANT:
if content.type == ContentType.AUDIO:
# Note that an assistant response can comprise of multiple audio blocks
if not self._assistant_is_responding:
# The assistant has started responding.
self._assistant_is_responding = True
await self._report_assistant_response_started()
async def _handle_text_output_event(self, event_json):
if not self._content_being_received: # should never happen
return
content = self._content_being_received
text_content = event_json["textOutput"]["content"]
# Bookkeeping: augment the current content being received with text
# Assumption: only one text content per content block
content.text_content = text_content
async def _handle_audio_output_event(self, event_json):
if not self._content_being_received: # should never happen
return
# Get audio
audio_content = event_json["audioOutput"]["content"]
# Push audio frame
audio = base64.b64decode(audio_content)
frame = TTSAudioRawFrame(
audio=audio,
sample_rate=self._params.output_sample_rate,
num_channels=self._params.output_channel_count,
)
await self.push_frame(frame)
async def _handle_tool_use_event(self, event_json):
if not self._content_being_received or not self._context: # should never happen
return
# Get tool use details
tool_use = event_json["toolUse"]
function_name = tool_use["toolName"]
tool_call_id = tool_use["toolUseId"]
arguments = json.loads(tool_use["content"])
# Call tool function
if self.has_function(function_name):
if function_name in self._functions.keys() or None in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_call_id,
function_name=function_name,
arguments=arguments,
)
else:
raise AWSNovaSonicUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
async def _handle_content_end_event(self, event_json):
if not self._content_being_received: # should never happen
return
content = self._content_being_received
content_end = event_json["contentEnd"]
stop_reason = content_end["stopReason"]
# Bookkeeping: clear current content being received
self._content_being_received = None
if content.role == Role.ASSISTANT:
if content.type == ContentType.TEXT:
# Ignore non-final text, and the "interrupted" message (which isn't meaningful text)
if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED":
if self._assistant_is_responding:
# Text added to the ongoing assistant response
await self._report_assistant_response_text_added(content.text_content)
elif content.role == Role.USER:
if content.type == ContentType.TEXT:
if content.text_stage == TextStage.FINAL:
# User transcription text added
await self._report_user_transcription_text_added(content.text_content)
async def _handle_completion_end_event(self, event_json):
pass
async def _report_assistant_response_started(self):
logger.debug("Assistant response started")
# Report that the assistant has started their response.
await self.push_frame(LLMFullResponseStartFrame())
# Report that equivalent of TTS (this is a speech-to-speech model) started
await self.push_frame(TTSStartedFrame())
async def _report_assistant_response_text_added(self, text):
if not self._context: # should never happen
return
logger.debug(f"Assistant response text added: {text}")
# Report some text added to the ongoing assistant response
await self.push_frame(LLMTextFrame(text))
# Report some text added to the *equivalent* of TTS (this is a speech-to-speech model)
await self.push_frame(TTSTextFrame(text))
# TODO: this is a (hopefully temporary) HACK. Here we directly manipulate the context rather
# than relying on the frames pushed to the assistant context aggregator. The pattern of
# receiving full-sentence text after the assistant has spoken does not easily fit with the
# Pipecat expectation of chunks of text streaming in while the assistant is speaking.
# Interruption handling was especially challenging. Rather than spend days trying to fit a
# square peg in a round hole, I decided on this hack for the time being. We can most cleanly
# abandon this hack if/when AWS Nova Sonic implements streaming smaller text chunks
# interspersed with audio. Note that when we move away from this hack, we need to make sure
# that on an interruption we avoid sending LLMFullResponseEndFrame, which gets the
# LLMAssistantContextAggregator into a bad state.
self._context.buffer_assistant_text(text)
async def _report_assistant_response_ended(self):
if not self._context: # should never happen
return
logger.debug("Assistant response ended")
# Report that the assistant has finished their response.
await self.push_frame(LLMFullResponseEndFrame())
# Report that equivalent of TTS (this is a speech-to-speech model) stopped.
await self.push_frame(TTSStoppedFrame())
# For an explanation of this hack, see _report_assistant_response_text_added.
self._context.flush_aggregated_assistant_text()
async def _report_user_transcription_text_added(self, text):
if not self._context: # should never happen
return
logger.debug(f"User transcription text added: {text}")
# Manually add new user transcription text to context.
# We can't rely on the user context aggregator to do this since it's upstream from the LLM.
self._context.add_user_transcription_text(text)
# Report that some new user transcription text is available.
if self._send_transcription_frames:
await self.push_frame(
TranscriptionFrame(text=text, user_id="", timestamp=time_now_iso8601())
)
#
# context
#
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSNovaSonicContextAggregatorPair:
context.set_llm_adapter(self.get_llm_adapter())
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params)
return AWSNovaSonicContextAggregatorPair(user, assistant)
#
# assistant response trigger (HACK)
#
# Class variable
AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION = (
"Start speaking when you hear the user say 'ready', but don't consider that 'ready' to be "
"a meaningful part of the conversation other than as a trigger for you to start speaking."
)
async def trigger_assistant_response(self):
if self._triggering_assistant_response:
return False
self._triggering_assistant_response = True
# Read audio bytes, if we don't already have them cached
if not self._assistant_response_trigger_audio:
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
with wave.open(file_path.open("rb"), "rb") as wav_file:
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
# Send the trigger audio, if we're fully connected and set up
if self._connected_time is not None:
await self._send_assistant_response_trigger()
self._triggering_assistant_response = False
async def _send_assistant_response_trigger(self):
if (
not self._assistant_response_trigger_audio or self._connected_time is None
): # should never happen
return
logger.debug("Sending assistant response trigger...")
chunk_duration = 0.02 # what we might get from InputAudioRawFrame
chunk_size = int(
chunk_duration
* self._params.input_sample_rate
* self._params.input_channel_count
* (self._params.input_sample_size / 8)
) # e.g. 0.02 seconds of 16-bit (2-byte) PCM mono audio at 16kHz is 640 bytes
# Lead with a bit of blank audio, if needed.
# It seems like the LLM can't quite "hear" the first little bit of audio sent on a
# connection.
current_time = time.time()
max_blank_audio_duration = 0.5
blank_audio_duration = (
max_blank_audio_duration - (current_time - self._connected_time)
if self._connected_time is not None
and (current_time - self._connected_time) < max_blank_audio_duration
else None
)
if blank_audio_duration:
logger.debug(
f"Leading assistant response trigger with {blank_audio_duration}s of blank audio"
)
blank_audio_chunk = b"\x00" * chunk_size
num_chunks = int(blank_audio_duration / chunk_duration)
for _ in range(num_chunks):
await self._send_user_audio_event(blank_audio_chunk)
await asyncio.sleep(chunk_duration)
# Send trigger audio
# NOTE: this audio *will* be transcribed and eventually make it into the context. That's OK:
# if we ever need to seed this service again with context it would make sense to include it
# since the instruction (i.e. the "wait for the trigger" instruction) will be part of the
# context as well.
audio_chunks = [
self._assistant_response_trigger_audio[i : i + chunk_size]
for i in range(0, len(self._assistant_response_trigger_audio), chunk_size)
]
for chunk in audio_chunks:
await self._send_user_audio_event(chunk)
await asyncio.sleep(chunk_duration)

View File

@@ -1,217 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import copy
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
DataFrame,
Frame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
class Role(Enum):
SYSTEM = "SYSTEM"
USER = "USER"
ASSISTANT = "ASSISTANT"
TOOL = "TOOL"
@dataclass
class AWSNovaSonicConversationHistoryMessage:
role: Role # only USER and ASSISTANT
text: str
@dataclass
class AWSNovaSonicConversationHistory:
system_instruction: str = None
messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
class AWSNovaSonicLLMContext(OpenAILLMContext):
def __init__(self, messages=None, tools=None, **kwargs):
super().__init__(messages=messages, tools=tools, **kwargs)
self.__setup_local()
def __setup_local(self, system_instruction: str = ""):
self._assistant_text = ""
self._system_instruction = system_instruction
@staticmethod
def upgrade_to_nova_sonic(
obj: OpenAILLMContext, system_instruction: str
) -> "AWSNovaSonicLLMContext":
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
obj.__class__ = AWSNovaSonicLLMContext
obj.__setup_local(system_instruction)
return obj
# NOTE: this method has the side-effect of updating _system_instruction from messages
def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)
# Bail if there are no messages
if not self.messages:
return history
messages = copy.deepcopy(self.messages)
# If we have a "system" message as our first message, let's pull that out into "instruction"
if messages[0].get("role") == "system":
system = messages.pop(0)
content = system.get("content")
if isinstance(content, str):
history.system_instruction = content
elif isinstance(content, list):
history.system_instruction = content[0].get("text")
if history.system_instruction:
self._system_instruction = history.system_instruction
# Process remaining messages to fill out conversation history.
# Nova Sonic supports "user" and "assistant" messages in history.
for message in messages:
history_message = self.from_standard_message(message)
if history_message:
history.messages.append(history_message)
return history
def get_messages_for_persistent_storage(self):
messages = super().get_messages_for_persistent_storage()
# If we have a system instruction and messages doesn't already contain it, add it
if self._system_instruction and not (messages and messages[0].get("role") == "system"):
messages.insert(0, {"role": "system", "content": self._system_instruction})
return messages
def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage:
role = message.get("role")
if message.get("role") == "user" or message.get("role") == "assistant":
content = message.get("content")
if isinstance(message.get("content"), list):
content = ""
for c in message.get("content"):
if c.get("type") == "text":
content += " " + c.get("text")
else:
logger.error(
f"Unhandled content type in context message: {c.get('type')} - {message}"
)
# There won't be content if this is an assistant tool call entry.
# We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
# history
if content:
return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
# NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova
# Sonic conversation history
def add_user_transcription_text(self, text):
message = {
"role": "user",
"content": [{"type": "text", "text": text}],
}
self.add_message(message)
# logger.debug(f"Context updated (user): {self.get_messages_for_logging()}")
def buffer_assistant_text(self, text):
self._assistant_text += text
# logger.debug(f"Assistant text buffered: {self._assistant_text}")
def flush_aggregated_assistant_text(self):
if not self._assistant_text:
return
message = {
"role": "assistant",
"content": [{"type": "text", "text": self._assistant_text}],
}
self._assistant_text = ""
self.add_message(message)
# logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}")
@dataclass
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
context: AWSNovaSonicLLMContext
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
async def process_frame(
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
):
await super().process_frame(frame, direction)
# Parent does not push LLMMessagesUpdateFrame
if isinstance(frame, LLMMessagesUpdateFrame):
await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context))
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def process_frame(self, frame: Frame, direction: FrameDirection):
# HACK: For now, disable the context aggregator by making it just pass through all frames
# that the parent handles (except the function call stuff, which we still need).
# For an explanation of this hack, see
# AWSNovaSonicLLMService._report_assistant_response_text_added.
if isinstance(
frame,
(
StartInterruptionFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
TextFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
LLMSetToolChoiceFrame,
UserImageRawFrame,
BotStoppedSpeakingFrame,
),
):
await self.push_frame(frame, direction)
else:
await super().process_frame(frame, direction)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
await super().handle_function_call_result(frame)
# The standard function callback code path pushes the FunctionCallResultFrame from the LLM
# itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side
# context. Let's push a special frame to do that.
await self.push_frame(
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
)
@dataclass
class AWSNovaSonicContextAggregatorPair:
_user: AWSNovaSonicUserContextAggregator
_assistant: AWSNovaSonicAssistantContextAggregator
def user(self) -> AWSNovaSonicUserContextAggregator:
return self._user
def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
return self._assistant

View File

@@ -1,14 +0,0 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from dataclasses import dataclass
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
@dataclass
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
result_frame: FunctionCallResultFrame

View File

@@ -0,0 +1,13 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import sys
from pipecat.services import DeprecatedModuleProxy
from .metrics import *
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics")

View File

@@ -0,0 +1,230 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import io
import os
import uuid
import wave
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import aiohttp
from loguru import logger
from pipecat.frames.frames import CancelFrame, EndFrame, Frame
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
try:
import aiofiles
import aiofiles.os
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Canonical Metrics, you need to `pip install pipecat-ai[canonical]`. "
+ "Also, set the `CANONICAL_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
# Multipart upload part size in bytes, cannot be smaller than 5MB
PART_SIZE = 1024 * 1024 * 5
class CanonicalMetricsService(AIService):
"""Initialize a CanonicalAudioProcessor instance.
This class uses an AudioBufferProcessor to get the conversation audio and
uploads it to Canonical Voice API for audio processing.
Args:
call_id (str): Your unique identifier for the call. This is used to match the call in the Canonical Voice system to the call in your system.
assistant (str): Identifier for the AI assistant. This can be whatever you want, it's intended for you convenience so you can distinguish
between different assistants and a grouping mechanism for calls.
assistant_speaks_first (bool, optional): Indicates if the assistant speaks first in the conversation. Defaults to True.
output_dir (str, optional): Directory to save temporary audio files. Defaults to "recordings".
Attributes:
call_id (str): Stores the unique call identifier.
assistant (str): Stores the assistant identifier.
assistant_speaks_first (bool): Indicates whether the assistant speaks first.
output_dir (str): Directory path for saving temporary audio files.
The constructor also ensures that the output directory exists.
"""
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
call_id: str,
assistant: str,
api_key: str,
api_url: str = "https://voiceapp.canonical.chat/api/v1",
assistant_speaks_first: bool = True,
output_dir: str = "recordings",
audio_buffer_processor: Optional[AudioBufferProcessor] = None,
context: Optional[OpenAILLMContext] = None,
**kwargs,
):
super().__init__(**kwargs)
# Validate that at least one of audio_buffer_processor or context is provided
if audio_buffer_processor is None and context is None:
raise ValueError("At least one of audio_buffer_processor or context must be specified")
self._aiohttp_session = aiohttp_session
self._audio_buffer_processor = audio_buffer_processor
self._api_key = api_key
self._api_url = api_url
self._call_id = call_id
self._assistant = assistant
self._assistant_speaks_first = assistant_speaks_first
self._output_dir = output_dir
self._context = context
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._process_completion()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._process_completion()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
async def _process_completion(self):
if self._audio_buffer_processor is not None:
await self._process_audio()
elif self._context is not None:
await self._process_transcript()
async def _process_transcript(self):
params = {
"callId": self._call_id,
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
"transcript": self._context.messages,
}
response = await self._aiohttp_session.post(
f"{self._api_url}/call",
headers=self._request_headers(),
json=params,
)
if not response.ok:
logger.error(f"Failed to process transcript: {await response.text()}")
async def _process_audio(self):
audio_buffer_processor = self._audio_buffer_processor
if not audio_buffer_processor.has_audio():
return
os.makedirs(self._output_dir, exist_ok=True)
filename = self._get_output_filename()
audio = audio_buffer_processor.merge_audio_buffers()
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wf:
wf.setsampwidth(2)
wf.setnchannels(audio_buffer_processor.num_channels)
wf.setframerate(audio_buffer_processor.sample_rate)
wf.writeframes(audio)
async with aiofiles.open(filename, "wb") as file:
await file.write(buffer.getvalue())
try:
await self._multipart_upload(filename)
await aiofiles.os.remove(filename)
except FileNotFoundError:
pass
except Exception as e:
logger.error(f"Failed to upload recording: {e}")
def _get_output_filename(self):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{self._output_dir}/{timestamp}-{uuid.uuid4().hex}.wav"
def _request_headers(self):
return {"Content-Type": "application/json", "X-Canonical-Api-Key": self._api_key}
async def _multipart_upload(self, file_path: str):
upload_request, upload_response = await self._request_upload(file_path)
if upload_request is None or upload_response is None:
return
parts = await self._upload_parts(file_path, upload_response)
if parts is None:
return
await self._upload_complete(parts, upload_request, upload_response)
async def _request_upload(self, file_path: str) -> Tuple[Dict, Dict]:
filename = os.path.basename(file_path)
filesize = os.path.getsize(file_path)
numparts = int((filesize + PART_SIZE - 1) / PART_SIZE)
params = {
"filename": filename,
"parts": numparts,
"callId": self._call_id,
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
}
logger.debug(f"Requesting presigned URLs for {numparts} parts")
response = await self._aiohttp_session.post(
f"{self._api_url}/recording/uploadRequest", headers=self._request_headers(), json=params
)
if not response.ok:
logger.error(f"Failed to get presigned URLs: {await response.text()}")
return None, None
response_json = await response.json()
return params, response_json
async def _upload_parts(self, file_path: str, upload_response: Dict) -> List[Dict]:
urls = upload_response["urls"]
parts = []
try:
async with aiofiles.open(file_path, "rb") as file:
for partnum, upload_url in enumerate(urls, start=1):
data = await file.read(PART_SIZE)
if not data:
break
response = await self._aiohttp_session.put(upload_url, data=data)
if not response.ok:
logger.error(f"Failed to upload part {partnum}: {await response.text()}")
return None
etag = response.headers["ETag"]
parts.append({"partnum": str(partnum), "etag": etag})
except Exception as e:
logger.error(f"Multipart upload aborted, an error occurred: {str(e)}")
return parts
async def _upload_complete(
self, parts: List[Dict], upload_request: Dict, upload_response: Dict
):
params = {
"filename": upload_request["filename"],
"parts": parts,
"slug": upload_response["slug"],
"callId": self._call_id,
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
}
if self._context is not None:
params["transcript"] = self._context.messages
logger.debug(f"Completing upload for {params['filename']}")
logger.debug(f"Slug: {params['slug']}")
response = await self._aiohttp_session.post(
f"{self._api_url}/recording/uploadComplete",
headers=self._request_headers(),
json=params,
)
if not response.ok:
logger.error(f"Failed to complete upload: {await response.text()}")
return

View File

@@ -30,7 +30,7 @@ class DeepgramTTSService(TTSService):
self,
*,
api_key: str,
voice: str = "aura-2-helena-en",
voice: str = "aura-helios-en",
base_url: str = "",
sample_rate: Optional[int] = None,
encoding: str = "linear16",

View File

@@ -7,12 +7,11 @@
import asyncio
import base64
import json
import uuid
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union
import aiohttp
from loguru import logger
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from pipecat.frames.frames import (
CancelFrame,
@@ -27,10 +26,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import (
AudioContextWordTTSService,
WordTTSService,
)
from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService
from pipecat.transcriptions.language import Language
# See .env.example for ElevenLabs configuration needed
@@ -163,17 +159,26 @@ def calculate_word_times(
return word_times
class ElevenLabsTTSService(AudioContextWordTTSService):
class ElevenLabsTTSService(InterruptibleWordTTSService):
class InputParams(BaseModel):
language: Optional[Language] = None
optimize_streaming_latency: Optional[str] = None
stability: Optional[float] = None
similarity_boost: Optional[float] = None
style: Optional[float] = None
use_speaker_boost: Optional[bool] = None
speed: Optional[float] = None
auto_mode: Optional[bool] = True
enable_ssml_parsing: Optional[bool] = None
enable_logging: Optional[bool] = None
@model_validator(mode="after")
def validate_voice_settings(self):
stability = self.stability
similarity_boost = self.similarity_boost
if (stability is None) != (similarity_boost is None):
raise ValueError(
"Both 'stability' and 'similarity_boost' must be provided when using voice settings"
)
return self
def __init__(
self,
@@ -215,14 +220,13 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
"language": self.language_to_service_language(params.language)
if params.language
else None,
"optimize_streaming_latency": params.optimize_streaming_latency,
"stability": params.stability,
"similarity_boost": params.similarity_boost,
"style": params.style,
"use_speaker_boost": params.use_speaker_boost,
"speed": params.speed,
"auto_mode": str(params.auto_mode).lower(),
"enable_ssml_parsing": params.enable_ssml_parsing,
"enable_logging": params.enable_logging,
}
self.set_model_name(model)
self.set_voice(voice_id)
@@ -234,8 +238,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
self._started = False
self._cumulative_time = 0
# Context management for v1 multi API
self._context_id = None
self._receive_task = None
self._keepalive_task = None
@@ -251,13 +253,15 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
async def set_model(self, model: str):
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")
# No need to disconnect/reconnect for model changes with multi-context API
await self._disconnect()
await self._connect()
async def _update_settings(self, settings: Mapping[str, Any]):
prev_voice = self._voice_id
await super()._update_settings(settings)
# If voice changes, we don't need to reconnect, just use a new context
if not prev_voice == self._voice_id:
await self._disconnect()
await self._connect()
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
async def start(self, frame: StartFrame):
@@ -274,8 +278,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._disconnect()
async def flush_audio(self):
if self._websocket and self._context_id:
msg = {"context_id": self._context_id, "flush": True}
if self._websocket:
msg = {"text": " ", "flush": True}
await self._websocket.send(json.dumps(msg))
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -315,13 +319,10 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
voice_id = self._voice_id
model = self.model_name
output_format = self._output_format
url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
if self._settings["enable_ssml_parsing"]:
url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}"
if self._settings["enable_logging"]:
url += f"&enable_logging={self._settings['enable_logging']}"
if self._settings["optimize_streaming_latency"]:
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
# Language can only be used with the ELEVENLABS_MULTILINGUAL_MODELS
language = self._settings["language"]
@@ -336,6 +337,14 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
# Set max websocket message size to 16MB for large audio responses
self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
# According to ElevenLabs, we should always start with a single space.
msg: Dict[str, Any] = {
"text": " ",
"xi_api_key": self._api_key,
}
if self._voice_settings:
msg["voice_settings"] = self._voice_settings
await self._websocket.send(json.dumps(msg))
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
@@ -347,15 +356,12 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
if self._websocket:
logger.debug("Disconnecting from ElevenLabs")
# Close all contexts and the socket
if self._context_id:
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.send(json.dumps({"text": ""}))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
finally:
self._started = False
self._context_id = None
self._websocket = None
def _get_websocket(self):
@@ -363,35 +369,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
# Close the current context when interrupted without closing the websocket
if self._context_id and self._websocket:
logger.trace(f"Closing context {self._context_id} due to interruption")
try:
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
except Exception as e:
logger.error(f"Error closing context on interruption: {e}")
self._context_id = None
self._started = False
async def _receive_messages(self):
async for message in self._get_websocket():
msg = json.loads(message)
# Check if this message belongs to the current context
# The default context may return null/None for context_id
received_ctx_id = msg.get("context_id")
if (
self._context_id is not None
and received_ctx_id is not None
and received_ctx_id != self._context_id
):
logger.trace(f"Ignoring message from different context: {received_ctx_id}")
continue
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()
@@ -403,45 +383,20 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
if msg.get("is_final"):
logger.trace(f"Received final message for context {received_ctx_id}")
# Context has finished
if self._context_id == received_ctx_id:
self._context_id = None
self._started = False
async def _keepalive_task_handler(self):
while True:
await asyncio.sleep(10)
try:
# Send an empty message to keep the connection alive
if self._websocket and self._websocket.open:
await self._websocket.send(json.dumps({}))
await self._send_text("")
except websockets.ConnectionClosed as e:
logger.warning(f"{self} keepalive error: {e}")
break
async def _send_text(self, text: str):
if self._websocket:
if not self._context_id:
# First message for a new context - need a space to initialize
msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
# Add voice settings only in first message for a context
if self._voice_settings:
msg["voice_settings"] = self._voice_settings
await self._websocket.send(json.dumps(msg))
self._context_id = msg["context_id"]
logger.trace(f"Created new context {self._context_id}")
# Now send the actual text content
msg = {"text": text, "context_id": self._context_id}
await self._websocket.send(json.dumps(msg))
else:
# Continuing with an existing context
msg = {"text": text, "context_id": self._context_id}
await self._websocket.send(json.dumps(msg))
msg = {"text": text + " "}
await self._websocket.send(json.dumps(msg))
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
@@ -451,13 +406,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._connect()
try:
# Close previous context if there was one
if self._context_id and not self._started:
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
self._context_id = None
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
@@ -469,8 +417,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
except Exception as e:
logger.error(f"{self} error sending message: {e}")
yield TTSStoppedFrame()
self._started = False
self._context_id = None
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:

View File

@@ -190,7 +190,6 @@ class LLMService(AIService):
function_name: Optional[str] = None,
tool_call_id: Optional[str] = None,
text_content: Optional[str] = None,
video_source: Optional[str] = None,
):
await self.push_frame(
UserImageRequestFrame(
@@ -198,7 +197,6 @@ class LLMService(AIService):
function_name=function_name,
tool_call_id=tool_call_id,
context=text_content,
video_source=video_source,
),
FrameDirection.UPSTREAM,
)

View File

@@ -577,7 +577,15 @@ class OpenAIRealtimeBetaLLMService(LLMService):
arguments = json.loads(item.arguments)
if self.has_function(function_name):
run_llm = index == total_items - 1
if function_name in self._functions.keys() or None in self._functions.keys():
if function_name in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_id,
function_name=function_name,
arguments=arguments,
run_llm=run_llm,
)
elif None in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_id,

View File

@@ -122,7 +122,6 @@ class BaseInputTransport(FrameProcessor):
# Configure VAD analyzer.
if self._params.vad_analyzer:
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
# Configure End of turn analyzer.
if self._params.turn_analyzer:
self._params.turn_analyzer.set_sample_rate(self._sample_rate)
@@ -130,6 +129,10 @@ class BaseInputTransport(FrameProcessor):
# Start audio filter.
if self._params.audio_in_filter:
await self._params.audio_in_filter.start(self._sample_rate)
# Create audio input queue and task if needed.
if not self._audio_task and self._params.audio_in_enabled:
self._audio_in_queue = asyncio.Queue()
self._audio_task = self.create_task(self._audio_task_handler())
async def stop(self, frame: EndFrame):
# Cancel and wait for the audio input task to finish.
@@ -146,13 +149,6 @@ class BaseInputTransport(FrameProcessor):
await self.cancel_task(self._audio_task)
self._audio_task = None
async def set_transport_ready(self, frame: StartFrame):
"""To be called when the transport is ready to stream."""
# Create audio input queue and task if needed.
if not self._audio_task and self._params.audio_in_enabled:
self._audio_in_queue = asyncio.Queue()
self._audio_task = self.create_task(self._audio_task_handler())
async def push_audio_frame(self, frame: InputAudioRawFrame):
if self._params.audio_in_enabled:
await self._audio_in_queue.put(frame)

View File

@@ -78,16 +78,6 @@ class BaseOutputTransport(FrameProcessor):
audio_bytes_10ms = int(self._sample_rate / 100) * self._params.audio_out_channels * 2
self._audio_chunk_size = audio_bytes_10ms * self._params.audio_out_10ms_chunks
async def stop(self, frame: EndFrame):
for _, sender in self._media_senders.items():
await sender.stop(frame)
async def cancel(self, frame: CancelFrame):
for _, sender in self._media_senders.items():
await sender.cancel(frame)
async def set_transport_ready(self, frame: StartFrame):
"""To be called when the transport is ready to stream."""
# Register destinations.
for destination in self._params.audio_out_destinations:
await self.register_audio_destination(destination)
@@ -122,6 +112,14 @@ class BaseOutputTransport(FrameProcessor):
)
await self._media_senders[destination].start(frame)
async def stop(self, frame: EndFrame):
for _, sender in self._media_senders.items():
await sender.stop(frame)
async def cancel(self, frame: CancelFrame):
for _, sender in self._media_senders.items():
await sender.cancel(frame)
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
pass

View File

@@ -61,8 +61,6 @@ class LocalAudioInputTransport(BaseInputTransport):
)
self._in_stream.start_stream()
await self.set_transport_ready(frame)
async def cleanup(self):
await super().cleanup()
if self._in_stream:
@@ -113,8 +111,6 @@ class LocalAudioOutputTransport(BaseOutputTransport):
)
self._out_stream.start_stream()
await self.set_transport_ready(frame)
async def cleanup(self):
await super().cleanup()
if self._out_stream:

View File

@@ -68,8 +68,6 @@ class TkInputTransport(BaseInputTransport):
)
self._in_stream.start_stream()
await self.set_transport_ready(frame)
async def cleanup(self):
await super().cleanup()
if self._in_stream:
@@ -126,8 +124,6 @@ class TkOutputTransport(BaseOutputTransport):
)
self._out_stream.start_stream()
await self.set_transport_ready(frame)
async def cleanup(self):
await super().cleanup()
if self._out_stream:

View File

@@ -131,7 +131,6 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
await self._client.trigger_client_connected()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_messages())
await self.set_transport_ready(frame)
async def _stop_tasks(self):
if self._monitor_websocket_task:
@@ -205,7 +204,6 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
await self._client.setup(frame)
await self._params.serializer.setup(frame)
self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)

View File

@@ -395,7 +395,6 @@ class SmallWebRTCInputTransport(BaseInputTransport):
self._receive_audio_task = self.create_task(self._receive_audio())
if not self._receive_video_task and self._params.video_in_enabled:
self._receive_video_task = self.create_task(self._receive_video())
await self.set_transport_ready(frame)
async def _stop_tasks(self):
if self._receive_audio_task:
@@ -488,7 +487,6 @@ class SmallWebRTCOutputTransport(BaseOutputTransport):
await super().start(frame)
await self._client.setup(self._params, frame)
await self._client.connect()
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)

View File

@@ -136,7 +136,6 @@ class WebsocketClientInputTransport(BaseInputTransport):
await self._params.serializer.setup(frame)
await self._session.setup(frame)
await self._session.connect()
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)
@@ -187,7 +186,6 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
await self._params.serializer.setup(frame)
await self._session.setup(frame)
await self._session.connect()
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)

View File

@@ -83,7 +83,6 @@ class WebsocketServerInputTransport(BaseInputTransport):
await self._params.serializer.setup(frame)
if not self._server_task:
self._server_task = self.create_task(self._server_task_handler())
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)
@@ -196,7 +195,6 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
await super().start(frame)
await self._params.serializer.setup(frame)
self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
await super().stop(frame)

View File

@@ -175,7 +175,6 @@ class DailyCallbacks(BaseModel):
"""Callback handlers for Daily events.
Attributes:
on_active_speaker_changed: Called when the active speaker of the call has changed.
on_joined: Called when bot successfully joined a room.
on_left: Called when bot left a room.
on_error: Called when an error occurs.
@@ -202,7 +201,6 @@ class DailyCallbacks(BaseModel):
on_recording_error: Called when recording encounters an error.
"""
on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]]
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
on_left: Callable[[], Awaitable[None]]
on_error: Callable[[str], Awaitable[None]]
@@ -700,7 +698,7 @@ class DailyTransportClient(EventHandler):
await self.update_subscriptions(participant_settings={participant_id: media})
self._audio_renderers.setdefault(participant_id, {})[audio_source] = callback
self._audio_renderers[participant_id] = {audio_source: callback}
self._client.set_audio_renderer(
participant_id,
@@ -724,7 +722,7 @@ class DailyTransportClient(EventHandler):
await self.update_subscriptions(participant_settings={participant_id: media})
self._video_renderers.setdefault(participant_id, {})[video_source] = callback
self._video_renderers[participant_id] = {video_source: callback}
self._client.set_video_renderer(
participant_id,
@@ -791,9 +789,6 @@ class DailyTransportClient(EventHandler):
# Daily (EventHandler)
#
def on_active_speaker_changed(self, participant):
self._call_async_callback(self._callbacks.on_active_speaker_changed, participant)
def on_app_message(self, message: Any, sender: str):
self._call_async_callback(self._callbacks.on_app_message, message, sender)
@@ -949,23 +944,19 @@ class DailyInputTransport(BaseInputTransport):
self._audio_in_task = self.create_task(self._audio_in_task_handler())
async def start(self, frame: StartFrame):
# Setup client.
await self._client.setup(frame)
# Parent start.
await super().start(frame)
if self._initialized:
return
self._initialized = True
# Parent start.
await super().start(frame)
# Setup client.
await self._client.setup(frame)
# Join the room.
await self._client.join()
# Indicate the transport that we are connected.
await self.set_transport_ready(frame)
if self._params.audio_in_stream_on_start:
self.start_audio_in_streaming()
@@ -1061,13 +1052,12 @@ class DailyInputTransport(BaseInputTransport):
video_source: str = "camera",
color_format: str = "RGB",
):
if participant_id not in self._video_renderers:
self._video_renderers[participant_id] = {}
self._video_renderers[participant_id][video_source] = {
"framerate": framerate,
"timestamp": 0,
"render_next_frame": [],
self._video_renderers[participant_id] = {
video_source: {
"framerate": framerate,
"timestamp": 0,
"render_next_frame": [],
}
}
await self._client.capture_participant_video(
@@ -1076,8 +1066,7 @@ class DailyInputTransport(BaseInputTransport):
async def request_participant_image(self, frame: UserImageRequestFrame):
if frame.user_id in self._video_renderers:
video_source = frame.video_source if frame.video_source else "camera"
self._video_renderers[frame.user_id][video_source]["render_next_frame"].append(frame)
self._video_renderers[frame.user_id]["render_next_frame"].append(frame)
async def _on_participant_video_frame(
self, participant_id: str, video_frame: VideoFrame, video_source: str
@@ -1136,23 +1125,20 @@ class DailyOutputTransport(BaseOutputTransport):
self._initialized = False
async def start(self, frame: StartFrame):
# Setup client.
await self._client.setup(frame)
# Parent start.
await super().start(frame)
if self._initialized:
return
self._initialized = True
# Parent start.
await super().start(frame)
# Setup client.
await self._client.setup(frame)
# Join the room.
await self._client.join()
# Indicate the transport that we are connected.
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
# Parent stop.
await super().stop(frame)
@@ -1215,7 +1201,6 @@ class DailyTransport(BaseTransport):
super().__init__(input_name=input_name, output_name=output_name)
callbacks = DailyCallbacks(
on_active_speaker_changed=self._on_active_speaker_changed,
on_joined=self._on_joined,
on_left=self._on_left,
on_error=self._on_error,
@@ -1251,7 +1236,6 @@ class DailyTransport(BaseTransport):
# Register supported handlers. The user will only be able to register
# these handlers.
self._register_event_handler("on_active_speaker_changed")
self._register_event_handler("on_joined")
self._register_event_handler("on_left")
self._register_event_handler("on_error")
@@ -1386,9 +1370,6 @@ class DailyTransport(BaseTransport):
async def update_remote_participants(self, remote_participants: Mapping[str, Any]):
await self._client.update_remote_participants(remote_participants=remote_participants)
async def _on_active_speaker_changed(self, participant: Any):
await self._call_event_handler("on_active_speaker_changed", participant)
async def _on_joined(self, data):
await self._call_event_handler("on_joined", data)

View File

@@ -370,7 +370,6 @@ class LiveKitInputTransport(BaseInputTransport):
await self._client.connect()
if not self._audio_in_task and self._params.audio_in_enabled:
self._audio_in_task = self.create_task(self._audio_in_task_handler())
await self.set_transport_ready(frame)
logger.info("LiveKitInputTransport started")
async def stop(self, frame: EndFrame):
@@ -442,7 +441,6 @@ class LiveKitOutputTransport(BaseOutputTransport):
await super().start(frame)
await self._client.setup(frame)
await self._client.connect()
await self.set_transport_ready(frame)
logger.info("LiveKitOutputTransport started")
async def stop(self, frame: EndFrame):

View File

@@ -6,7 +6,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Coroutine, Dict, Optional, Sequence, Set
from typing import Coroutine, Optional, Set
from loguru import logger
@@ -69,14 +69,14 @@ class BaseTaskManager(ABC):
pass
@abstractmethod
def current_tasks(self) -> Sequence[asyncio.Task]:
def current_tasks(self) -> Set[asyncio.Task]:
"""Returns the list of currently created/registered tasks."""
pass
class TaskManager(BaseTaskManager):
def __init__(self) -> None:
self._tasks: Dict[str, asyncio.Task] = {}
self._tasks: Set[asyncio.Task] = set()
self._loop: Optional[asyncio.AbstractEventLoop] = None
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
@@ -179,17 +179,16 @@ class TaskManager(BaseTaskManager):
finally:
self._remove_task(task)
def current_tasks(self) -> Sequence[asyncio.Task]:
def current_tasks(self) -> Set[asyncio.Task]:
"""Returns the list of currently created/registered tasks."""
return list(self._tasks.values())
return self._tasks
def _add_task(self, task: asyncio.Task):
name = task.get_name()
self._tasks[name] = task
self._tasks.add(task)
def _remove_task(self, task: asyncio.Task):
name = task.get_name()
try:
del self._tasks[name]
self._tasks.remove(task)
except KeyError as e:
logger.trace(f"{name}: unable to remove task (already removed?): {e}")

View File

@@ -1 +1 @@
-e ".[anthropic,aws,google,langchain]"
-e ".[anthropic,google,langchain]"

View File

@@ -40,11 +40,6 @@ from pipecat.services.anthropic.llm import (
AnthropicLLMContext,
AnthropicUserContextAggregator,
)
from pipecat.services.aws.llm import (
AWSBedrockAssistantContextAggregator,
AWSBedrockLLMContext,
AWSBedrockUserContextAggregator,
)
from pipecat.services.google.llm import (
GoogleAssistantContextAggregator,
GoogleLLMContext,
@@ -674,6 +669,26 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
AGGREGATOR_CLASS = LLMUserContextAggregator
#
# OpenAI
#
class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
#
# Anthropic
#
@@ -709,43 +724,6 @@ class TestAnthropicAssistantContextAggregator(
assert context.messages[index]["content"][0]["content"] == json.dumps(content)
#
# AWS (Bedrock)
#
class TestAWSBedrockUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
class TestAWSBedrockAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
assert context.messages[index]["content"][0]["toolResult"]["content"][0][
"text"
] == json.dumps(content)
#
# Google
#
@@ -788,23 +766,3 @@ class TestGoogleAssistantContextAggregator(
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
#
# OpenAI
#
class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]

View File

@@ -9,7 +9,6 @@ import unittest
from pipecat.frames.frames import (
EndFrame,
Frame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -58,8 +57,8 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
async def test_system_frame(self):
filter = FrameFilter(types=())
frames_to_send = [StartInterruptionFrame()]
expected_down_frames = [StartInterruptionFrame]
frames_to_send = [UserStartedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame]
await run_test(
filter,
frames_to_send=frames_to_send,

View File

@@ -11,7 +11,6 @@ from openai.types.chat import ChatCompletionToolParam
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
@@ -175,32 +174,3 @@ class TestFunctionAdapters(unittest.TestCase):
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
def test_bedrock_adapter(self):
"""Test AWS Bedrock adapter format transformation."""
expected = [
{
"toolSpec": {
"name": "get_weather",
"description": "Get the weather in a given location",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
},
"required": ["location", "format"],
}
},
}
}
]
assert AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def) == expected