diff --git a/CHANGELOG.md b/CHANGELOG.md index 54cbd3cd6..9b85c7f87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,16 +5,71 @@ All notable changes to **Pipecat** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.0.67] - 2025-05-07 ### Added +- Added `DebugLogObserver` for detailed frame logging with configurable + filtering by frame type and endpoint. This observer automatically extracts + and formats all frame data fields for debug logging. + +- `UserImageRequestFrame.video_source` field has been added to request an image + from the desired video source. + +- Added support for the AWS Nova Sonic speech-to-speech model with the new + `AWSNovaSonicLLMService`. + See https://docs.aws.amazon.com/nova/latest/userguide/speech.html. + Note that it requires Python >= 3.12 and `pip install pipecat-ai[aws-nova-sonic]`. + +- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`. + +- Added `on_active_speaker_changed` event handler to the `DailyTransport` class. + +- Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in + `ElevenLabsTTSService`. + - Added support to `RimeHttpTTSService` for the `arcana` model. +### Changed + +- Updated `ElevenLabsTTSService` to use the beta websocket API + (multi-stream-input). This new API supports context_ids and cancelling those + contexts, which greatly improves interruption handling. + +- Observers `on_push_frame()` now take a single argument `FramePushed` instead + of multiple arguments. + +- Updated the default voice for `DeepgramTTSService` to `aura-2-helena-en`. + +### Deprecated + +- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead. + +- Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now + deprecated, use `on_push_frame(data: FramePushed)` instead. + ### Fixed +- Fixed a `DailyTransport` issue that was causing issues when multiple audio or + video sources where being captured. + +- Fixed a `UltravoxSTTService` issue that would cause the service to generate + all tokens as one word. + +- Fixed a `PipelineTask` issue that would cause tasks to not be cancelled if + task was cancelled from outside of Pipecat. + +- Fixed a `TaskManager` that was causing dangling tasks to be reported. + +- Fixed an issue that could cause data to be sent to the transports when they + were still not ready. + - Remove custom audio tracks from `DailyTransport` before leaving. +### Removed + +- Removed `CanonicalMetricsService` as it's no longer maintained. + ## [0.0.66] - 2025-05-02 ### Added diff --git a/README.md b/README.md index ac2444f87..ec3b0a791 100644 --- a/README.md +++ b/README.md @@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs: ## 🧩 Available services -| Category | Services | -| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | -| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | -| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | -| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | -| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | -| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | -| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) | -| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | -| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | -| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | +| Category | Services | +|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | +| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | +| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | +| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | +| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | +| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | +| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) | +| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | +| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | +| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | 📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services) diff --git a/docs/api/requirements.txt b/docs/api/requirements.txt index 9badccd8f..a77ff1084 100644 --- a/docs/api/requirements.txt +++ b/docs/api/requirements.txt @@ -10,7 +10,6 @@ pipecat-ai[anthropic] pipecat-ai[assemblyai] pipecat-ai[aws] pipecat-ai[azure] -pipecat-ai[canonical] pipecat-ai[cartesia] pipecat-ai[cerebras] pipecat-ai[deepseek] diff --git a/examples/canonical-metrics/.gitignore b/examples/canonical-metrics/.gitignore deleted file mode 100644 index 50d9d205e..000000000 --- a/examples/canonical-metrics/.gitignore +++ /dev/null @@ -1,161 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class -recordings/ -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ -runpod.toml diff --git a/examples/canonical-metrics/Dockerfile b/examples/canonical-metrics/Dockerfile deleted file mode 100644 index a5b4668c6..000000000 --- a/examples/canonical-metrics/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM python:3.10-bullseye -RUN mkdir /app -COPY *.py /app/ -COPY requirements.txt /app/ -WORKDIR /app -RUN pip3 install -r requirements.txt - -EXPOSE 7860 - -CMD ["python3", "server.py"] diff --git a/examples/canonical-metrics/README.md b/examples/canonical-metrics/README.md deleted file mode 100644 index 068655d2b..000000000 --- a/examples/canonical-metrics/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Chatbot with canonical-metrics - -This project implements a chatbot using a pipeline architecture that integrates audio processing, transcription, and a language model for conversational interactions. The chatbot operates within a daily communication environment, utilizing various services for text-to-speech and language model responses. - -## Features - -- **Audio Input and Output**: Captures microphone input and plays back audio responses. -- **Voice Activity Detection**: Utilizes Silero VAD to manage audio input intelligently. -- **Text-to-Speech**: Integrates ElevenLabs TTS service to convert text responses into audio. -- **Language Model Interaction**: Uses OpenAI's GPT-4 model to generate responses based on user input. -- **Transcription Services**: Captures and transcribes participant speech for analytics. -- **Metrics Collection**: Sends audio data for analysis via Canonical Metrics Service. - -## Requirements - -- Python 3.10+ -- `python-dotenv` -- Additional libraries from the `pipecat` package. - -## Setup - -1. Clone the repository. -2. Install the required packages. -3. Set up environment variables for API keys: - - `OPENAI_API_KEY` - - `ELEVENLABS_API_KEY` - - `CANONICAL_API_KEY` - - `CANONICAL_API_URL` -4. Run the script. - -## Usage - -The chatbot introduces itself and engages in conversations, providing brief and creative responses. Designed for flexibility, it can support multiple languages with appropriate configuration. - -## Events - -- Participants joining or leaving the call are handled dynamically, adjusting the chatbot's behavior accordingly. - - -ℹ️ The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded. - -## Get started - -```python -python3 -m venv venv -source venv/bin/activate -pip install -r requirements.txt - -cp env.example .env # and add your credentials - -``` - -## Run the server - -```bash -python server.py -``` - -Then, visit `http://localhost:7860/` in your browser to start a chatbot session. - -## Build and test the Docker image - -``` -docker build -t chatbot . -docker run --env-file .env -p 7860:7860 chatbot -``` diff --git a/examples/canonical-metrics/bot.py b/examples/canonical-metrics/bot.py deleted file mode 100644 index 871d0542d..000000000 --- a/examples/canonical-metrics/bot.py +++ /dev/null @@ -1,146 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import asyncio -import os -import sys -import uuid - -import aiohttp -from dotenv import load_dotenv -from loguru import logger -from runner import configure - -from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.frames.frames import EndFrame -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor -from pipecat.services.canonical.metrics import CanonicalMetricsService -from pipecat.services.elevenlabs.tts import ElevenLabsTTSService -from pipecat.services.openai.llm import OpenAILLMService -from pipecat.transports.services.daily import DailyParams, DailyTransport - -load_dotenv(override=True) - -logger.remove(0) -logger.add(sys.stderr, level="DEBUG") - - -async def main(): - async with aiohttp.ClientSession() as session: - (room_url, token) = await configure(session) - - transport = DailyTransport( - room_url, - token, - "Chatbot", - DailyParams( - audio_out_enabled=True, - audio_in_enabled=True, - video_out_enabled=False, - vad_analyzer=SileroVADAnalyzer(), - transcription_enabled=True, - # - # Spanish - # - # transcription_settings=DailyTranscriptionSettings( - # language="es", - # tier="nova", - # model="2-general" - # ) - ), - ) - - tts = ElevenLabsTTSService( - api_key=os.getenv("ELEVENLABS_API_KEY"), - # - # English - # - voice_id="cgSgspJ2msm6clMCkdW9", - # - # Spanish - # - # model="eleven_multilingual_v2", - # voice_id="gD1IexrzCvsXPHUuT0s3", - ) - - llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) - - messages = [ - { - "role": "system", - # - # English - # - "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself. Keep all your responses to 12 words or fewer.", - # - # Spanish - # - # "content": "Eres Chatbot, un amigable y útil robot. Tu objetivo es demostrar tus capacidades de una manera breve. Tus respuestas se convertiran a audio así que nunca no debes incluir caracteres especiales. Contesta a lo que el usuario pregunte de una manera creativa, útil y breve. Empieza por presentarte a ti mismo.", - }, - ] - - context = OpenAILLMContext(messages) - context_aggregator = llm.create_context_aggregator(context) - - """ - CanonicalMetrics uses AudioBufferProcessor under the hood to buffer the audio. On - call completion, CanonicalMetrics will send the audio buffer to Canonical for - analysis. Visit https://voice.canonical.chat to learn more. - """ - audio_buffer_processor = AudioBufferProcessor(num_channels=2) - canonical = CanonicalMetricsService( - audio_buffer_processor=audio_buffer_processor, - aiohttp_session=session, - api_key=os.getenv("CANONICAL_API_KEY"), - call_id=str(uuid.uuid4()), - assistant="pipecat-chatbot", - assistant_speaks_first=True, - context=context, - ) - pipeline = Pipeline( - [ - transport.input(), # microphone - context_aggregator.user(), - llm, - tts, - transport.output(), - canonical, # uploads audio buffer to Canonical AI for metrics - audio_buffer_processor, # captures audio into a buffer - context_aggregator.assistant(), - ] - ) - - task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True)) - - @transport.event_handler("on_first_participant_joined") - async def on_first_participant_joined(transport, participant): - await audio_buffer_processor.start_recording() - await transport.capture_participant_transcription(participant["id"]) - await task.queue_frames([context_aggregator.user().get_context_frame()]) - - @transport.event_handler("on_participant_left") - async def on_participant_left(transport, participant, reason): - print(f"Participant left: {participant}") - await task.cancel() - - @transport.event_handler("on_call_state_updated") - async def on_call_state_updated(transport, state): - if state == "left": - # Here we don't want to cancel, we just want to finish sending - # whatever is queued, so we use an EndFrame(). - await task.queue_frame(EndFrame()) - - runner = PipelineRunner() - - await runner.run(task) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/canonical-metrics/env.example b/examples/canonical-metrics/env.example deleted file mode 100644 index 6b865401a..000000000 --- a/examples/canonical-metrics/env.example +++ /dev/null @@ -1,6 +0,0 @@ -DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev) -DAILY_API_KEY=7df... -OPENAI_API_KEY=sk-PL... -ELEVENLABS_API_KEY=aeb... -CANONICAL_API_KEY=can... -CANONICAL_API_URL= diff --git a/examples/canonical-metrics/requirements.txt b/examples/canonical-metrics/requirements.txt deleted file mode 100644 index 7e53edc6b..000000000 --- a/examples/canonical-metrics/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -python-dotenv -fastapi[all] -uvicorn -pipecat-ai[daily,openai,silero,elevenlabs,canonical] - diff --git a/examples/canonical-metrics/runner.py b/examples/canonical-metrics/runner.py deleted file mode 100644 index ad39a3ac4..000000000 --- a/examples/canonical-metrics/runner.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import argparse -import os - -import aiohttp - -from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper - - -async def configure(aiohttp_session: aiohttp.ClientSession): - parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample") - parser.add_argument( - "-u", "--url", type=str, required=False, help="URL of the Daily room to join" - ) - parser.add_argument( - "-k", - "--apikey", - type=str, - required=False, - help="Daily API Key (needed to create an owner token for the room)", - ) - - args, unknown = parser.parse_known_args() - - url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL") - key = args.apikey or os.getenv("DAILY_API_KEY") - - if not url: - raise Exception( - "No Daily room specified. use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL." - ) - - if not key: - raise Exception( - "No Daily API key specified. use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers." - ) - - daily_rest_helper = DailyRESTHelper( - daily_api_key=key, - daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"), - aiohttp_session=aiohttp_session, - ) - - # Create a meeting token for the given room with an expiration 1 hour in - # the future. - expiry_time: float = 60 * 60 - - token = await daily_rest_helper.get_token(url, expiry_time) - - return (url, token) diff --git a/examples/canonical-metrics/server.py b/examples/canonical-metrics/server.py deleted file mode 100644 index a0f38854c..000000000 --- a/examples/canonical-metrics/server.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import argparse -import os -import subprocess -from contextlib import asynccontextmanager - -import aiohttp -from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, RedirectResponse - -from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams - -MAX_BOTS_PER_ROOM = 1 - -# Bot sub-process dict for status reporting and concurrency control -bot_procs = {} - -daily_helpers = {} - -load_dotenv(override=True) - - -def cleanup(): - # Clean up function, just to be extra safe - for entry in bot_procs.values(): - proc = entry[0] - proc.terminate() - proc.wait() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - aiohttp_session = aiohttp.ClientSession() - daily_helpers["rest"] = DailyRESTHelper( - daily_api_key=os.getenv("DAILY_API_KEY", ""), - daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"), - aiohttp_session=aiohttp_session, - ) - yield - await aiohttp_session.close() - cleanup() - - -app = FastAPI(lifespan=lifespan) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.get("/") -async def start_agent(request: Request): - print(f"!!! Creating room") - room = await daily_helpers["rest"].create_room(DailyRoomParams()) - print(f"!!! Room URL: {room.url}") - # Ensure the room property is present - if not room.url: - raise HTTPException( - status_code=500, - detail="Missing 'room' property in request data. Cannot start agent without a target room!", - ) - - # Check if there is already an existing process running in this room - num_bots_in_room = sum( - 1 for proc in bot_procs.values() if proc[1] == room.url and proc[0].poll() is None - ) - if num_bots_in_room >= MAX_BOTS_PER_ROOM: - raise HTTPException(status_code=500, detail=f"Max bot limited reach for room: {room.url}") - - # Get the token for the room - token = await daily_helpers["rest"].get_token(room.url) - - if not token: - raise HTTPException(status_code=500, detail=f"Failed to get token for room: {room.url}") - - # Spawn a new agent, and join the user session - # Note: this is mostly for demonstration purposes (refer to 'deployment' in README) - try: - proc = subprocess.Popen( - [f"python3 -m bot -u {room.url} -t {token}"], - shell=True, - bufsize=1, - cwd=os.path.dirname(os.path.abspath(__file__)), - ) - bot_procs[proc.pid] = (proc, room.url) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to start subprocess: {e}") - - return RedirectResponse(room.url) - - -@app.get("/status/{pid}") -def get_status(pid: int): - # Look up the subprocess - proc = bot_procs.get(pid) - - # If the subprocess doesn't exist, return an error - if not proc: - raise HTTPException(status_code=404, detail=f"Bot with process id: {pid} not found") - - # Check the status of the subprocess - if proc[0].poll() is None: - status = "running" - else: - status = "finished" - - return JSONResponse({"bot_id": pid, "status": status}) - - -if __name__ == "__main__": - import uvicorn - - default_host = os.getenv("HOST", "0.0.0.0") - default_port = int(os.getenv("FAST_API_PORT", "7860")) - - parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server") - parser.add_argument("--host", type=str, default=default_host, help="Host address") - parser.add_argument("--port", type=int, default=default_port, help="Port number") - parser.add_argument("--reload", action="store_true", help="Reload code on change") - - config = parser.parse_args() - - uvicorn.run( - "server:app", - host=config.host, - port=config.port, - reload=config.reload, - ) diff --git a/examples/foundational/07c-interruptible-deepgram-vad.py b/examples/foundational/07c-interruptible-deepgram-vad.py index a6d6ab4bb..945cdc447 100644 --- a/examples/foundational/07c-interruptible-deepgram-vad.py +++ b/examples/foundational/07c-interruptible-deepgram-vad.py @@ -47,7 +47,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac live_options=LiveOptions(vad_events=True, utterance_end_ms="1000"), ) - tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en") + tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en") llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) diff --git a/examples/foundational/07c-interruptible-deepgram.py b/examples/foundational/07c-interruptible-deepgram.py index 3e02d8d77..2a707da4a 100644 --- a/examples/foundational/07c-interruptible-deepgram.py +++ b/examples/foundational/07c-interruptible-deepgram.py @@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en") + tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en") llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) diff --git a/examples/foundational/07m-interruptible-polly.py b/examples/foundational/07m-interruptible-aws.py similarity index 79% rename from examples/foundational/07m-interruptible-polly.py rename to examples/foundational/07m-interruptible-aws.py index 286fe5128..bbcfe7313 100644 --- a/examples/foundational/07m-interruptible-polly.py +++ b/examples/foundational/07m-interruptible-aws.py @@ -5,7 +5,6 @@ # import argparse -import os from dotenv import load_dotenv from loguru import logger @@ -15,9 +14,9 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.aws.tts import PollyTTSService -from pipecat.services.deepgram.stt import DeepgramSTTService -from pipecat.services.openai.llm import OpenAILLMService +from pipecat.services.aws.llm import AWSBedrockLLMService +from pipecat.services.aws.stt import AWSTranscribeSTTService +from pipecat.services.aws.tts import AWSPollyTTSService from pipecat.transports.base_transport import TransportParams from pipecat.transports.network.small_webrtc import SmallWebRTCTransport from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection @@ -37,17 +36,19 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac ), ) - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + stt = AWSTranscribeSTTService() - tts = PollyTTSService( - api_key=os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), - region=os.getenv("AWS_REGION"), - voice_id="Amy", - params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"), + tts = AWSPollyTTSService( + region="us-west-2", # only specific regions support generative TTS + voice_id="Joanna", + params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"), ) - llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) + llm = AWSBedrockLLMService( + aws_region="us-west-2", + model="us.anthropic.claude-3-5-haiku-20241022-v1:0", + params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"), + ) messages = [ { @@ -85,7 +86,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac async def on_client_connected(transport, client): logger.info(f"Client connected") # Kick off the conversation. - messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + messages.append({"role": "user", "content": "Please introduce yourself to the user."}) await task.queue_frames([context_aggregator.user().get_context_frame()]) @transport.event_handler("on_client_disconnected") diff --git a/examples/foundational/14r-function-calling-aws.py b/examples/foundational/14r-function-calling-aws.py new file mode 100644 index 000000000..cf4859576 --- /dev/null +++ b/examples/foundational/14r-function-calling-aws.py @@ -0,0 +1,139 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.aws.llm import AWSBedrockLLMService +from pipecat.services.aws.stt import AWSTranscribeSTTService +from pipecat.services.aws.tts import AWSPollyTTSService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.transports.base_transport import TransportParams +from pipecat.transports.network.small_webrtc import SmallWebRTCTransport +from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection + +load_dotenv(override=True) + + +async def fetch_weather_from_api(params: FunctionCallParams): + await params.result_callback({"conditions": "nice", "temperature": "75"}) + + +async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace): + logger.info(f"Starting bot") + + transport = SmallWebRTCTransport( + webrtc_connection=webrtc_connection, + params=TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + stt = AWSTranscribeSTTService() + + tts = AWSPollyTTSService( + region="us-west-2", # only specific regions support generative TTS + voice_id="Joanna", + params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"), + ) + + llm = AWSBedrockLLMService( + aws_region="us-west-2", + model="us.anthropic.claude-3-5-haiku-20241022-v1:0", + params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"), + ) + + # You can also register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_function("get_current_weather", fetch_weather_from_api) + + weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + required=["location", "format"], + ) + tools = ToolsSchema(standard_tools=[weather_function]) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append({"role": "user", "content": "Please introduce yourself to the user."}) + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + + @transport.event_handler("on_client_closed") + async def on_client_closed(transport, client): + logger.info(f"Client closed connection") + await task.cancel() + + runner = PipelineRunner(handle_sigint=False) + + await runner.run(task) + + +if __name__ == "__main__": + from run import main + + main() diff --git a/examples/foundational/20e-persistent-context-aws-nova-sonic.py b/examples/foundational/20e-persistent-context-aws-nova-sonic.py new file mode 100644 index 000000000..1519f1c53 --- /dev/null +++ b/examples/foundational/20e-persistent-context-aws-nova-sonic.py @@ -0,0 +1,267 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import asyncio +import glob +import json +import os +from datetime import datetime + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.aws_nova_sonic.aws import AWSNovaSonicLLMService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.transports.base_transport import TransportParams +from pipecat.transports.network.small_webrtc import SmallWebRTCTransport +from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection + +load_dotenv(override=True) + +BASE_FILENAME = "/tmp/pipecat_conversation_" + + +async def fetch_weather_from_api(params: FunctionCallParams): + temperature = 75 if params.arguments["format"] == "fahrenheit" else 24 + await params.result_callback( + { + "conditions": "nice", + "temperature": temperature, + "format": params.arguments["format"], + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), + } + ) + + +async def get_saved_conversation_filenames(params: FunctionCallParams): + # Construct the full pattern including the BASE_FILENAME + full_pattern = f"{BASE_FILENAME}*.json" + + # Use glob to find all matching files + matching_files = glob.glob(full_pattern) + logger.debug(f"matching files: {matching_files}") + + await params.result_callback({"filenames": matching_files}) + + +# async def get_saved_conversation_filenames( +# function_name, tool_call_id, args, llm, context, result_callback +# ): +# pattern = re.compile(re.escape(BASE_FILENAME) + "\\d{8}_\\d{6}\\.json$") +# matching_files = [] + +# for filename in os.listdir("."): +# if pattern.match(filename): +# matching_files.append(filename) + +# await result_callback({"filenames": matching_files}) + + +async def save_conversation(params: FunctionCallParams): + timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + filename = f"{BASE_FILENAME}{timestamp}.json" + try: + with open(filename, "w") as file: + messages = params.context.get_messages_for_persistent_storage() + # remove the last few messages. in reverse order, they are: + # - the in progress save tool call + # - the invocation of the save tool call + # - the user ask to save (which may encompass one or more messages) + # the simplest thing to do is to pop messages until the last one is an assistant + # response + while messages and not ( + messages[-1].get("role") == "assistant" and "content" in messages[-1] + ): + messages.pop() + if messages: # we never expect this to be empty + logger.debug( + f"writing conversation to {filename}\n{json.dumps(messages, indent=4)}" + ) + json.dump(messages, file, indent=2) + await params.result_callback({"success": True}) + except Exception as e: + await params.result_callback({"success": False, "error": str(e)}) + + +async def load_conversation(params: FunctionCallParams): + async def _reset(): + filename = params.arguments["filename"] + logger.debug(f"loading conversation from {filename}") + try: + with open(filename, "r") as file: + messages = json.load(file) + messages.append( + { + "role": "user", + "content": f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}", + } + ) + params.context.set_messages(messages) + await params.llm.reset_conversation() + await params.llm.trigger_assistant_response() + except Exception as e: + await params.result_callback({"success": False, "error": str(e)}) + + asyncio.create_task(_reset()) + + +get_current_weather_tool = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + required=["location", "format"], +) + +save_conversation_tool = FunctionSchema( + name="save_conversation", + description="Save the current conversation. Use this function to persist the current conversation to external storage.", + properties={}, + required=[], +) + +get_saved_conversation_filenames_tool = FunctionSchema( + name="get_saved_conversation_filenames", + description="Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.", + properties={}, + required=[], +) + +load_conversation_tool = FunctionSchema( + name="load_conversation", + description="Load a conversation history. Use this function to load a conversation history into the current session.", + properties={ + "filename": { + "type": "string", + "description": "The filename of the conversation history to load.", + } + }, + required=["filename"], +) + +tools = ToolsSchema( + standard_tools=[ + get_current_weather_tool, + save_conversation_tool, + get_saved_conversation_filenames_tool, + load_conversation_tool, + ] +) + + +async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace): + logger.info(f"Starting bot") + + transport = SmallWebRTCTransport( + webrtc_connection=webrtc_connection, + params=TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)), + ), + ) + + # Specify initial system instruction. + # HACK: note that, for now, we need to inject a special bit of text into this instruction to + # allow the first assistant response to be programmatically triggered (which happens in the + # on_client_connected handler, below) + system_instruction = ( + "You are a friendly assistant. The user and you will engage in a spoken dialog exchanging " + "the transcripts of a natural real-time conversation. Keep your responses short, generally " + "two or three sentences for chatty scenarios. " + f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}" + ) + + llm = AWSNovaSonicLLMService( + secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region + voice_id="tiffany", # matthew, tiffany, amy + # you could choose to pass instruction here rather than via context + # system_instruction=system_instruction, + # you could choose to pass tools here rather than via context + # tools=tools + ) + + llm.register_function("get_current_weather", fetch_weather_from_api) + llm.register_function("save_conversation", save_conversation) + llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames) + llm.register_function("load_conversation", load_conversation) + + context = OpenAILLMContext( + messages=[ + {"role": "system", "content": f"{system_instruction}"}, + ], + tools=tools, + ) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + context_aggregator.user(), + llm, # LLM + transport.output(), # Transport bot output + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + # HACK: for now, we need this special way of triggering the first assistant response in AWS + # Nova Sonic. Note that this trigger requires a special corresponding bit of text in the + # system instruction. In the future, simply queueing the context frame should be sufficient. + await llm.trigger_assistant_response() + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + + @transport.event_handler("on_client_closed") + async def on_client_closed(transport, client): + logger.info(f"Client closed connection") + await task.cancel() + + runner = PipelineRunner(handle_sigint=False) + + await runner.run(task) + + +if __name__ == "__main__": + from run import main + + main() diff --git a/examples/foundational/30-observer.py b/examples/foundational/30-observer.py index d8c2ec100..c9cd08aee 100644 --- a/examples/foundational/30-observer.py +++ b/examples/foundational/30-observer.py @@ -14,19 +14,26 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, - Frame, + EndFrame, StartInterruptionFrame, + TTSTextFrame, + UserStartedSpeakingFrame, ) -from pipecat.observers.base_observer import BaseObserver +from pipecat.observers.base_observer import BaseObserver, FramePushed +from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint from pipecat.observers.loggers.llm_log_observer import LLMLogObserver from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, +) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.cartesia.tts import CartesiaTTSService from pipecat.services.deepgram.stt import DeepgramSTTService from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import TransportParams from pipecat.transports.network.small_webrtc import SmallWebRTCTransport from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection @@ -34,7 +41,7 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection load_dotenv(override=True) -class DebugObserver(BaseObserver): +class CustomObserver(BaseObserver): """Observer to log interruptions and bot speaking events to the console. Logs all frame instances of: @@ -46,21 +53,20 @@ class DebugObserver(BaseObserver): Log format: [EVENT TYPE]: [source processor] → [destination processor] at [timestamp]s """ - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): + src = data.source + dst = data.destination + frame = data.frame + direction = data.direction + timestamp = data.timestamp + # Convert timestamp to seconds for readability time_sec = timestamp / 1_000_000_000 # Create direction arrow arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←" - if isinstance(frame, StartInterruptionFrame): + if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport): logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s") elif isinstance(frame, BotStartedSpeakingFrame): logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s") @@ -119,7 +125,17 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac enable_usage_metrics=True, report_only_initial_ttfb=True, ), - observers=[DebugObserver(), LLMLogObserver()], + observers=[ + CustomObserver(), + LLMLogObserver(), + DebugLogObserver( + frame_types={ + TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION), + UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE), + EndFrame: None, + } + ), + ], ) @transport.event_handler("on_client_connected") diff --git a/examples/foundational/39-aws-nova-sonic.py b/examples/foundational/39-aws-nova-sonic.py new file mode 100644 index 000000000..4ed533e18 --- /dev/null +++ b/examples/foundational/39-aws-nova-sonic.py @@ -0,0 +1,173 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import argparse +import os +from datetime import datetime + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.transports.base_transport import TransportParams +from pipecat.transports.network.small_webrtc import SmallWebRTCTransport +from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection + +# Load environment variables +load_dotenv(override=True) + + +async def fetch_weather_from_api(params: FunctionCallParams): + temperature = 75 if params.arguments["format"] == "fahrenheit" else 24 + await params.result_callback( + { + "conditions": "nice", + "temperature": temperature, + "format": params.arguments["format"], + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), + } + ) + + +weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + required=["location", "format"], +) + +# Create tools schema +tools = ToolsSchema(standard_tools=[weather_function]) + + +async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace): + logger.info(f"Starting bot") + + # Initialize the SmallWebRTCTransport with the connection + transport = SmallWebRTCTransport( + webrtc_connection=webrtc_connection, + params=TransportParams( + audio_in_enabled=True, + audio_in_sample_rate=16000, + audio_out_enabled=True, + camera_in_enabled=False, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)), + ), + ) + + # Specify initial system instruction. + # HACK: note that, for now, we need to inject a special bit of text into this instruction to + # allow the first assistant response to be programmatically triggered (which happens in the + # on_client_connected handler, below) + system_instruction = ( + "You are a friendly assistant. The user and you will engage in a spoken dialog exchanging " + "the transcripts of a natural real-time conversation. Keep your responses short, generally " + "two or three sentences for chatty scenarios. " + f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}" + ) + + # Create the AWS Nova Sonic LLM service + llm = AWSNovaSonicLLMService( + secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region + voice_id="tiffany", # matthew, tiffany, amy + # you could choose to pass instruction here rather than via context + # system_instruction=system_instruction + # you could choose to pass tools here rather than via context + # tools=tools + ) + + # Register function for function calls + # you can either register a single function for all function calls, or specific functions + # llm.register_function(None, fetch_weather_from_api) + llm.register_function("get_current_weather", fetch_weather_from_api) + + # Set up context and context management. + # AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to + # what's expected by Nova Sonic. + context = OpenAILLMContext( + messages=[ + {"role": "system", "content": f"{system_instruction}"}, + { + "role": "user", + "content": "Tell me a fun fact!", + }, + ], + tools=tools, + ) + context_aggregator = llm.create_context_aggregator(context) + + # Build the pipeline + pipeline = Pipeline( + [ + transport.input(), + context_aggregator.user(), + llm, + transport.output(), + context_aggregator.assistant(), + ] + ) + + # Configure the pipeline task + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + # Handle client connection event + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + # HACK: for now, we need this special way of triggering the first assistant response in AWS + # Nova Sonic. Note that this trigger requires a special corresponding bit of text in the + # system instruction. In the future, simply queueing the context frame should be sufficient. + await llm.trigger_assistant_response() + + # Handle client disconnection events + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + + @transport.event_handler("on_client_closed") + async def on_client_closed(transport, client): + logger.info(f"Client closed connection") + await task.cancel() + + # Run the pipeline + runner = PipelineRunner(handle_sigint=False) + await runner.run(task) + + +if __name__ == "__main__": + from run import main + + main() diff --git a/examples/moondream-chatbot/server.py b/examples/moondream-chatbot/server.py index bb322ff2e..9597bdc9a 100644 --- a/examples/moondream-chatbot/server.py +++ b/examples/moondream-chatbot/server.py @@ -10,12 +10,16 @@ import subprocess from contextlib import asynccontextmanager import aiohttp +from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams +# Load environment variables from .env file +load_dotenv(override=True) + MAX_BOTS_PER_ROOM = 1 # Bot sub-process dict for status reporting and concurrency control diff --git a/examples/patient-intake/server.py b/examples/patient-intake/server.py index 347b17dbd..10ccfb3b7 100644 --- a/examples/patient-intake/server.py +++ b/examples/patient-intake/server.py @@ -10,12 +10,16 @@ import subprocess from contextlib import asynccontextmanager import aiohttp +from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams +# Load environment variables from .env file +load_dotenv(override=True) + MAX_BOTS_PER_ROOM = 1 # Bot sub-process dict for status reporting and concurrency control diff --git a/pyproject.toml b/pyproject.toml index ecddb0902..3b34569c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,13 +41,13 @@ Website = "https://pipecat.ai" [project.optional-dependencies] anthropic = [ "anthropic~=0.49.0" ] assemblyai = [ "assemblyai~=0.37.0" ] -aws = [ "boto3~=1.37.16" ] +aws = [ "boto3~=1.37.16", "websockets~=13.1" ] +aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ] azure = [ "azure-cognitiveservices-speech~=1.42.0"] -canonical = [ "aiofiles~=24.1.0" ] cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ] cerebras = [] deepseek = [] -daily = [ "daily-python~=0.18.1" ] +daily = [ "daily-python~=0.18.2" ] deepgram = [ "deepgram-sdk~=3.8.0" ] elevenlabs = [ "websockets~=13.1" ] fal = [ "fal-client~=0.5.9" ] @@ -97,6 +97,7 @@ where = ["src"] [tool.setuptools.package-data] "pipecat" = ["py.typed"] +"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"] [tool.pytest.ini_options] addopts = "--verbose" diff --git a/src/pipecat/adapters/services/anthropic_adapter.py b/src/pipecat/adapters/services/anthropic_adapter.py index a699469d3..23197d3a8 100644 --- a/src/pipecat/adapters/services/anthropic_adapter.py +++ b/src/pipecat/adapters/services/anthropic_adapter.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from pipecat.adapters.base_llm_adapter import BaseLLMAdapter from pipecat.adapters.schemas.function_schema import FunctionSchema diff --git a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py new file mode 100644 index 000000000..dc7eef92d --- /dev/null +++ b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# +import json +from typing import Any, Dict, List + +from pipecat.adapters.base_llm_adapter import BaseLLMAdapter +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema + + +class AWSNovaSonicLLMAdapter(BaseLLMAdapter): + @staticmethod + def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]: + return { + "toolSpec": { + "name": function.name, + "description": function.description, + "inputSchema": { + "json": json.dumps( + { + "type": "object", + "properties": function.properties, + "required": function.required, + } + ) + }, + } + } + + def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: + """Converts function schemas to AWS Nova Sonic function-calling format. + + :return: AWS Nova Sonic formatted function call definition. + """ + + functions_schema = tools_schema.standard_tools + return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/bedrock_adapter.py b/src/pipecat/adapters/services/bedrock_adapter.py new file mode 100644 index 000000000..113a6938d --- /dev/null +++ b/src/pipecat/adapters/services/bedrock_adapter.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Any, Dict, List + +from pipecat.adapters.base_llm_adapter import BaseLLMAdapter +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema + + +class AWSBedrockLLMAdapter(BaseLLMAdapter): + @staticmethod + def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]: + return { + "toolSpec": { + "name": function.name, + "description": function.description, + "inputSchema": { + "json": { + "type": "object", + "properties": function.properties, + "required": function.required, + }, + }, + } + } + + def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: + """Converts function schemas to Bedrock's function-calling format. + + :return: Bedrock formatted function call definition. + """ + + functions_schema = tools_schema.standard_tools + return [self._to_bedrock_function_format(func) for func in functions_schema] diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 05f5b666d..8d3f38459 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -715,9 +715,10 @@ class UserImageRequestFrame(SystemFrame): context: Optional[Any] = None function_name: Optional[str] = None tool_call_id: Optional[str] = None + video_source: Optional[str] = None def __str__(self): - return f"{self.name}(user: {self.user_id}, function: {self.function_name}, request: {self.tool_call_id})" + return f"{self.name}(user: {self.user_id}, video_source: {self.video_source}, function: {self.function_name}, request: {self.tool_call_id})" @dataclass diff --git a/src/pipecat/observers/base_observer.py b/src/pipecat/observers/base_observer.py index 46f746946..f1a0c2a1b 100644 --- a/src/pipecat/observers/base_observer.py +++ b/src/pipecat/observers/base_observer.py @@ -5,9 +5,38 @@ # from abc import ABC, abstractmethod +from dataclasses import dataclass + +from typing_extensions import TYPE_CHECKING from pipecat.frames.frames import Frame -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +if TYPE_CHECKING: + from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +@dataclass +class FramePushed: + """Represents an event where a frame is pushed from one processor to another + within the pipeline. + + This data structure is typically used by observers to track the flow of + frames through the pipeline for logging, debugging, or analytics purposes. + + Attributes: + source (FrameProcessor): The processor sending the frame. + destination (FrameProcessor): The processor receiving the frame. + frame (Frame): The frame being transferred. + direction (FrameDirection): The direction of the transfer (e.g., downstream or upstream). + timestamp (int): The time when the frame was pushed, based on the pipeline clock. + + """ + + source: "FrameProcessor" + destination: "FrameProcessor" + frame: Frame + direction: "FrameDirection" + timestamp: int class BaseObserver(ABC): @@ -19,26 +48,15 @@ class BaseObserver(ABC): """ @abstractmethod - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): - """Abstract method to handle the event when a frame is pushed from one - processor to another. + async def on_push_frame(self, data: FramePushed): + """Handle the event when a frame is pushed from one processor to another. + + This method should be implemented by subclasses to define specific + behavior (e.g., logging, monitoring, debugging) when a frame is + transferred through the pipeline. Args: - src (FrameProcessor): The source frame processor that is sending the frame. - dst (FrameProcessor): The destination frame processor that will receive the frame. - frame (Frame): The frame being transferred between processors. - direction (FrameDirection): The direction of the frame transfer. - timestamp (int): The timestamp when the frame was pushed (based on the pipeline clock). - - This method should be implemented by subclasses to define specific behavior - when a frame is pushed. + data (FramePushed): The event data containing details about the frame transfer. """ pass diff --git a/src/pipecat/observers/loggers/debug_log_observer.py b/src/pipecat/observers/loggers/debug_log_observer.py new file mode 100644 index 000000000..575a31683 --- /dev/null +++ b/src/pipecat/observers/loggers/debug_log_observer.py @@ -0,0 +1,218 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dataclasses import fields, is_dataclass +from enum import Enum, auto +from typing import Dict, Optional, Set, Tuple, Type, Union + +from loguru import logger + +from pipecat.frames.frames import Frame +from pipecat.observers.base_observer import BaseObserver, FramePushed +from pipecat.processors.frame_processor import FrameDirection + + +class FrameEndpoint(Enum): + """Specifies which endpoint (source or destination) to filter on.""" + + SOURCE = auto() + DESTINATION = auto() + + +class DebugLogObserver(BaseObserver): + """Observer that logs frame activity with detailed content to the console. + + Automatically extracts and formats data from any frame type, making it useful + for debugging pipeline behavior without needing frame-specific observers. + + Args: + frame_types: Optional tuple of frame types to log, or a dict with frame type + filters. If None, logs all frame types. + exclude_fields: Optional set of field names to exclude from logging. + + Examples: + Log all frames from all services: + ```python + observers = DebugLogObserver() + ``` + + Log specific frame types from any source/destination: + ```python + from pipecat.frames.frames import TranscriptionFrame, InterimTranscriptionFrame + observers = DebugLogObserver(frame_types=(TranscriptionFrame, InterimTranscriptionFrame)) + ``` + + Log frames with specific source/destination filters: + ```python + from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame + from pipecat.transports.base_output_transport import BaseOutputTransport + from pipecat.services.stt_service import STTService + + observers = DebugLogObserver(frame_types={ + # Only log StartInterruptionFrame when source is BaseOutputTransport + StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE), + + # Only log UserStartedSpeakingFrame when destination is STTService + UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION), + + # Log LLMTextFrame regardless of source or destination type + LLMTextFrame: None + }) + ``` + """ + + def __init__( + self, + frame_types: Optional[ + Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]] + ] = None, + exclude_fields: Optional[Set[str]] = None, + ): + """Initialize the debug log observer. + + Args: + frame_types: Tuple of frame types to log, or a dict mapping frame types to + filter configurations. Filter configs can be: + - None to log all instances of the frame type + - A tuple of (service_type, endpoint) to filter on a specific service + and endpoint (SOURCE or DESTINATION) + If None is provided instead of a tuple/dict, log all frames. + exclude_fields: Set of field names to exclude from logging. If None, only binary + data fields are excluded. + """ + # Process frame filters + self.frame_filters = {} + + if frame_types is not None: + if isinstance(frame_types, tuple): + # Tuple of frame types - log all instances + self.frame_filters = {frame_type: None for frame_type in frame_types} + else: + # Dict of frame types with filters + self.frame_filters = frame_types + + # By default, exclude binary data fields that would clutter logs + self.exclude_fields = ( + exclude_fields + if exclude_fields is not None + else { + "audio", # Skip binary audio data + "image", # Skip binary image data + "images", # Skip lists of images + } + ) + + def _format_value(self, value): + """Format a value for logging. + + Args: + value: The value to format. + + Returns: + str: A string representation of the value suitable for logging. + """ + if value is None: + return "None" + elif isinstance(value, str): + return f"{value!r}" + elif isinstance(value, (list, tuple)): + if len(value) == 0: + return "[]" + if isinstance(value[0], dict) and len(value) > 3: + # For message lists, just show count + return f"{len(value)} items" + return str(value) + elif isinstance(value, (bytes, bytearray)): + return f"{len(value)} bytes" + elif hasattr(value, "get_messages_for_logging") and callable( + getattr(value, "get_messages_for_logging") + ): + # Special case for OpenAI context + return f"{value.__class__.__name__} with messages: {value.get_messages_for_logging()}" + else: + return str(value) + + def _should_log_frame(self, frame, src, dst): + """Determine if a frame should be logged based on filters. + + Args: + frame: The frame being processed + src: The source component + dst: The destination component + + Returns: + bool: True if the frame should be logged, False otherwise + """ + # If no filters, log all frames + if not self.frame_filters: + return True + + # Check if this frame type is in our filters + for frame_type, filter_config in self.frame_filters.items(): + if isinstance(frame, frame_type): + # If filter is None, log all instances of this frame type + if filter_config is None: + return True + + # Otherwise, check the specific filter + service_type, endpoint = filter_config + + if endpoint == FrameEndpoint.SOURCE: + return isinstance(src, service_type) + elif endpoint == FrameEndpoint.DESTINATION: + return isinstance(dst, service_type) + + return False + + async def on_push_frame(self, data: FramePushed): + """Process a frame being pushed into the pipeline. + + Logs frame details to the console with all relevant fields and values. + + Args: + data: Event data containing the frame, source, destination, direction, and timestamp. + """ + src = data.source + dst = data.destination + frame = data.frame + direction = data.direction + timestamp = data.timestamp + + # Check if we should log this frame + if not self._should_log_frame(frame, src, dst): + return + + # Format direction arrow + arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←" + + time_sec = timestamp / 1_000_000_000 + class_name = frame.__class__.__name__ + + # Build frame representation + frame_details = [] + + # If dataclass, extract fields + if is_dataclass(frame): + for field in fields(frame): + if field.name in self.exclude_fields: + continue + + value = getattr(frame, field.name) + if value is None: + continue + + formatted_value = self._format_value(value) + frame_details.append(f"{field.name}: {formatted_value}") + + # Format the message + if frame_details: + details = ", ".join(frame_details) + message = f"{class_name} {details} at {time_sec:.2f}s" + else: + message = f"{class_name} at {time_sec:.2f}s" + + # Log the message + logger.debug(f"{src} {arrow} {dst}: {message}") diff --git a/src/pipecat/observers/loggers/llm_log_observer.py b/src/pipecat/observers/loggers/llm_log_observer.py index dd270abf5..a6675b5c0 100644 --- a/src/pipecat/observers/loggers/llm_log_observer.py +++ b/src/pipecat/observers/loggers/llm_log_observer.py @@ -7,7 +7,6 @@ from loguru import logger from pipecat.frames.frames import ( - Frame, FunctionCallInProgressFrame, FunctionCallResultFrame, LLMFullResponseEndFrame, @@ -15,9 +14,9 @@ from pipecat.frames.frames import ( LLMMessagesFrame, LLMTextFrame, ) -from pipecat.observers.base_observer import BaseObserver +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import LLMService @@ -38,14 +37,13 @@ class LLMLogObserver(BaseObserver): """ - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): + src = data.source + dst = data.destination + frame = data.frame + direction = data.direction + timestamp = data.timestamp + if not isinstance(src, LLMService) and not isinstance(dst, LLMService): return diff --git a/src/pipecat/observers/loggers/transcription_log_observer.py b/src/pipecat/observers/loggers/transcription_log_observer.py index 4547ee54f..8ca1d9c9b 100644 --- a/src/pipecat/observers/loggers/transcription_log_observer.py +++ b/src/pipecat/observers/loggers/transcription_log_observer.py @@ -7,12 +7,10 @@ from loguru import logger from pipecat.frames.frames import ( - Frame, InterimTranscriptionFrame, TranscriptionFrame, ) -from pipecat.observers.base_observer import BaseObserver -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.services.stt_service import STTService @@ -29,14 +27,11 @@ class TranscriptionLogObserver(BaseObserver): """ - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): + src = data.source + frame = data.frame + timestamp = data.timestamp + if not isinstance(src, STTService): return diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 8279373cb..c40173899 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -286,12 +286,7 @@ class PipelineTask(BaseTask): async def cancel(self): """Stops the running pipeline immediately.""" logger.debug(f"Canceling pipeline task {self}") - # Make sure everything is cleaned up downstream. This is sent - # out-of-band from the main streaming task which is what we want since - # we want to cancel right away. - await self._source.push_frame(CancelFrame()) - # Only cancel the push task. Everything else will be cancelled in run(). - await self._task_manager.cancel_task(self._process_push_task) + await self._cancel() async def run(self): """Starts and manages the pipeline execution until completion or cancellation.""" @@ -309,11 +304,17 @@ class PipelineTask(BaseTask): # well, because you get a CancelledError in every place you are # awaiting a task. pass - await self._cancel_tasks() - await self._cleanup(cleanup_pipeline) - if self._check_dangling_tasks: - self._print_dangling_tasks() - self._finished = True + finally: + # It's possibe that we get an asyncio.CancelledError from the + # outside, if so we need to make sure everything gets cancelled + # properly. + if cleanup_pipeline: + await self._cancel() + await self._cancel_tasks() + await self._cleanup(cleanup_pipeline) + if self._check_dangling_tasks: + self._print_dangling_tasks() + self._finished = True async def queue_frame(self, frame: Frame): """Queue a single frame to be pushed down the pipeline. @@ -336,6 +337,14 @@ class PipelineTask(BaseTask): for frame in frames: await self.queue_frame(frame) + async def _cancel(self): + # Make sure everything is cleaned up downstream. This is sent + # out-of-band from the main streaming task which is what we want since + # we want to cancel right away. + await self._source.push_frame(CancelFrame()) + # Only cancel the push task. Everything else will be cancelled in run(). + await self._task_manager.cancel_task(self._process_push_task) + async def _create_tasks(self): self._process_up_task = self._task_manager.create_task( self._process_up_queue(), f"{self}::_process_up_queue" diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index dd805032c..252708f8c 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -5,13 +5,12 @@ # import asyncio +import inspect from typing import List from attr import dataclass -from pipecat.frames.frames import Frame -from pipecat.observers.base_observer import BaseObserver -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.utils.asyncio import BaseTaskManager @@ -27,20 +26,6 @@ class Proxy: observer: BaseObserver -@dataclass -class ObserverData: - """This is the data we receive from the main observer and that we put into a - proxy queue for later processing. - - """ - - src: FrameProcessor - dst: FrameProcessor - frame: Frame - direction: FrameDirection - timestamp: int - - class TaskObserver(BaseObserver): """This is a pipeline frame observer that is meant to be used as a proxy to the user provided observers. That is, this is the observer that should be @@ -68,20 +53,9 @@ class TaskObserver(BaseObserver): for proxy in self._proxies: await self._task_manager.cancel_task(proxy.task) - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): for proxy in self._proxies: - await proxy.queue.put( - ObserverData( - src=src, dst=dst, frame=frame, direction=direction, timestamp=timestamp - ) - ) + await proxy.queue.put(data) def _create_proxies(self, observers) -> List[Proxy]: proxies = [] @@ -96,8 +70,26 @@ class TaskObserver(BaseObserver): return proxies async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver): + warning_reported = False while True: data = await queue.get() - await observer.on_push_frame( - data.src, data.dst, data.frame, data.direction, data.timestamp - ) + + signature = inspect.signature(observer.on_push_frame) + if len(signature.parameters) > 1: + if not warning_reported: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Observer `on_push_frame(source, destination, frame, direction, timestamp)` is deprecated, us `on_push_frame(data: FramePushed)` instead.", + DeprecationWarning, + ) + warning_reported = True + await observer.on_push_frame( + data.src, data.dst, data.frame, data.direction, data.timestamp + ) + else: + await observer.on_push_frame(data) + + queue.task_done() diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 590698e7f..97cc24378 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -21,6 +21,7 @@ from pipecat.frames.frames import ( SystemFrame, ) from pipecat.metrics.metrics import LLMTokenUsage, MetricsData +from pipecat.observers.base_observer import FramePushed from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.base_object import BaseObject @@ -294,17 +295,28 @@ class FrameProcessor(BaseObject): timestamp = self._clock.get_time() if self._clock else 0 if direction == FrameDirection.DOWNSTREAM and self._next: logger.trace(f"Pushing {frame} from {self} to {self._next}") + if self._observer: - await self._observer.on_push_frame( - self, self._next, frame, direction, timestamp + data = FramePushed( + source=self, + destination=self._next, + frame=frame, + direction=direction, + timestamp=timestamp, ) + await self._observer.on_push_frame(data) await self._next.queue_frame(frame, direction) elif direction == FrameDirection.UPSTREAM and self._prev: logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}") if self._observer: - await self._observer.on_push_frame( - self, self._prev, frame, direction, timestamp + data = FramePushed( + source=self, + destination=self._prev, + frame=frame, + direction=direction, + timestamp=timestamp, ) + await self._observer.on_push_frame(data) await self._prev.queue_frame(frame, direction) except Exception as e: logger.exception(f"Uncaught exception in {self}: {e}") diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 55e91d7ff..909dd15b7 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -55,7 +55,7 @@ from pipecat.metrics.metrics import ( TTFBMetricsData, TTSUsageMetricsData, ) -from pipecat.observers.base_observer import BaseObserver +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -254,7 +254,7 @@ class RTVIBotReady(BaseModel): class RTVILLMFunctionCallMessageData(BaseModel): function_name: str tool_call_id: str - arguments: Mapping[str, Any] + args: Mapping[str, Any] class RTVILLMFunctionCallMessage(BaseModel): @@ -445,14 +445,7 @@ class RTVIObserver(BaseObserver): self._frames_seen = set() rtvi.set_errors_enabled(self._params.errors_enabled) - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): """Process a frame being pushed through the pipeline. Args: @@ -462,6 +455,10 @@ class RTVIObserver(BaseObserver): direction: Direction of frame flow in pipeline timestamp: Time when frame was pushed """ + src = data.source + frame = data.frame + direction = data.direction + # If we have already seen this frame, let's skip it. if frame.id in self._frames_seen: return @@ -703,7 +700,7 @@ class RTVIProcessor(FrameProcessor): fn = RTVILLMFunctionCallMessageData( function_name=params.function_name, tool_call_id=params.tool_call_id, - arguments=params.arguments, + args=params.arguments, ) message = RTVILLMFunctionCallMessage(data=fn) await self._push_transport_message(message, exclude_none=False) diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index 277e29f83..eba6a5041 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -250,14 +250,24 @@ class AnthropicLLMService(LLMService): if hasattr(event.message.usage, "output_tokens") else 0 ) - if hasattr(event.message.usage, "cache_creation_input_tokens"): - cache_creation_input_tokens += ( - event.message.usage.cache_creation_input_tokens + cache_creation_input_tokens += ( + event.message.usage.cache_creation_input_tokens + if ( + hasattr(event.message.usage, "cache_creation_input_tokens") + and event.message.usage.cache_creation_input_tokens is not None ) - logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}") - if hasattr(event.message.usage, "cache_read_input_tokens"): - cache_read_input_tokens += event.message.usage.cache_read_input_tokens - logger.debug(f"Cache read input tokens: {cache_read_input_tokens}") + else 0 + ) + logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}") + cache_read_input_tokens += ( + event.message.usage.cache_read_input_tokens + if ( + hasattr(event.message.usage, "cache_read_input_tokens") + and event.message.usage.cache_read_input_tokens is not None + ) + else 0 + ) + logger.debug(f"Cache read input tokens: {cache_read_input_tokens}") total_input_tokens = ( prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens ) diff --git a/src/pipecat/services/aws/__init__.py b/src/pipecat/services/aws/__init__.py index b36c88499..b1f157bd3 100644 --- a/src/pipecat/services/aws/__init__.py +++ b/src/pipecat/services/aws/__init__.py @@ -8,6 +8,8 @@ import sys from pipecat.services import DeprecatedModuleProxy +from .llm import * +from .stt import * from .tts import * -sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts") +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]") diff --git a/src/pipecat/services/aws/llm.py b/src/pipecat/services/aws/llm.py new file mode 100644 index 000000000..921d3c790 --- /dev/null +++ b/src/pipecat/services/aws/llm.py @@ -0,0 +1,785 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import base64 +import copy +import io +import json +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from loguru import logger +from PIL import Image +from pydantic import BaseModel, Field + +from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter +from pipecat.frames.frames import ( + Frame, + FunctionCallCancelFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMTextFrame, + LLMUpdateSettingsFrame, + UserImageRawFrame, + VisionImageRawFrame, +) +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMAssistantContextAggregator, + LLMUserAggregatorParams, + LLMUserContextAggregator, +) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.llm_service import LLMService + +try: + import boto3 + import httpx + from botocore.config import Config +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." + ) + raise Exception(f"Missing module: {e}") + + +@dataclass +class AWSBedrockContextAggregatorPair: + _user: "AWSBedrockUserContextAggregator" + _assistant: "AWSBedrockAssistantContextAggregator" + + def user(self) -> "AWSBedrockUserContextAggregator": + return self._user + + def assistant(self) -> "AWSBedrockAssistantContextAggregator": + return self._assistant + + +class AWSBedrockLLMContext(OpenAILLMContext): + def __init__( + self, + messages: Optional[List[dict]] = None, + tools: Optional[List[dict]] = None, + tool_choice: Optional[dict] = None, + *, + system: Optional[str] = None, + ): + super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) + self.system = system + + @staticmethod + def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext": + logger.debug(f"Upgrading to AWS Bedrock: {obj}") + if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext): + obj.__class__ = AWSBedrockLLMContext + obj._restructure_from_openai_messages() + else: + obj._restructure_from_bedrock_messages() + return obj + + @classmethod + def from_openai_context(cls, openai_context: OpenAILLMContext): + self = cls( + messages=openai_context.messages, + tools=openai_context.tools, + tool_choice=openai_context.tool_choice, + ) + self.set_llm_adapter(openai_context.get_llm_adapter()) + self._restructure_from_openai_messages() + return self + + @classmethod + def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext": + self = cls(messages=messages) + self._restructure_from_openai_messages() + return self + + @classmethod + def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext": + context = cls() + context.add_image_frame_message( + format=frame.format, size=frame.size, image=frame.image, text=frame.text + ) + return context + + def set_messages(self, messages: List): + self._messages[:] = messages + self._restructure_from_openai_messages() + + # convert a message in AWS Bedrock format into one or more messages in OpenAI format + def to_standard_messages(self, obj): + """Convert AWS Bedrock message format to standard structured format. + + Handles text content and function calls for both user and assistant messages. + + Args: + obj: Message in AWS Bedrock format: + { + "role": "user/assistant", + "content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}] + } + + Returns: + List of messages in standard format: + [ + { + "role": "user/assistant/tool", + "content": [{"type": "text", "text": str}] + } + ] + """ + role = obj.get("role") + content = obj.get("content") + + if role == "assistant": + if isinstance(content, str): + return [{"role": role, "content": [{"type": "text", "text": content}]}] + elif isinstance(content, list): + text_items = [] + tool_items = [] + for item in content: + if "text" in item: + text_items.append({"type": "text", "text": item["text"]}) + elif "toolUse" in item: + tool_use = item["toolUse"] + tool_items.append( + { + "type": "function", + "id": tool_use["toolUseId"], + "function": { + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + ) + messages = [] + if text_items: + messages.append({"role": role, "content": text_items}) + if tool_items: + messages.append({"role": role, "tool_calls": tool_items}) + return messages + elif role == "user": + if isinstance(content, str): + return [{"role": role, "content": [{"type": "text", "text": content}]}] + elif isinstance(content, list): + text_items = [] + tool_items = [] + for item in content: + if "text" in item: + text_items.append({"type": "text", "text": item["text"]}) + elif "toolResult" in item: + tool_result = item["toolResult"] + # Extract content from toolResult + result_content = "" + if isinstance(tool_result["content"], list): + for content_item in tool_result["content"]: + if "text" in content_item: + result_content = content_item["text"] + elif "json" in content_item: + result_content = json.dumps(content_item["json"]) + else: + result_content = tool_result["content"] + + tool_items.append( + { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": result_content, + } + ) + messages = [] + if text_items: + messages.append({"role": role, "content": text_items}) + messages.extend(tool_items) + return messages + + def from_standard_message(self, message): + """Convert standard format message to AWS Bedrock format. + + Handles conversion of text content, tool calls, and tool results. + Empty text content is converted to "(empty)". + + Args: + message: Message in standard format: + { + "role": "user/assistant/tool", + "content": str | [{"type": "text", ...}], + "tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}] + } + + Returns: + Message in AWS Bedrock format: + { + "role": "user/assistant", + "content": [ + {"text": str} | + {"toolUse": {"toolUseId": str, "name": str, "input": dict}} | + {"toolResult": {"toolUseId": str, "content": [...], "status": str}} + ] + } + """ + if message["role"] == "tool": + # Try to parse the content as JSON if it looks like JSON + try: + if message["content"].strip().startswith("{") and message[ + "content" + ].strip().endswith("}"): + content_json = json.loads(message["content"]) + tool_result_content = [{"json": content_json}] + else: + tool_result_content = [{"text": message["content"]}] + except: + tool_result_content = [{"text": message["content"]}] + + return { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": message["tool_call_id"], + "content": tool_result_content, + }, + }, + ], + } + + if message.get("tool_calls"): + tc = message["tool_calls"] + ret = {"role": "assistant", "content": []} + for tool_call in tc: + function = tool_call["function"] + arguments = json.loads(function["arguments"]) + new_tool_use = { + "toolUse": { + "toolUseId": tool_call["id"], + "name": function["name"], + "input": arguments, + } + } + ret["content"].append(new_tool_use) + return ret + + # Handle text content + content = message.get("content") + if isinstance(content, str): + if content == "": + return {"role": message["role"], "content": [{"text": "(empty)"}]} + else: + return {"role": message["role"], "content": [{"text": content}]} + elif isinstance(content, list): + new_content = [] + for item in content: + if item.get("type", "") == "text": + text_content = item["text"] if item["text"] != "" else "(empty)" + new_content.append({"text": text_content}) + return {"role": message["role"], "content": new_content} + + return message + + def add_image_frame_message( + self, *, format: str, size: tuple[int, int], image: bytes, text: str = None + ): + buffer = io.BytesIO() + Image.frombytes(format, size, image).save(buffer, format="JPEG") + encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") + + # Image should be the first content block in the message + content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}] + if text: + content.append({"text": text}) + self.add_message({"role": "user", "content": content}) + + def add_message(self, message): + try: + if self.messages: + # AWS Bedrock requires that roles alternate. If this message's + # role is the same as the last message, we should add this + # message's content to the last message. + if self.messages[-1]["role"] == message["role"]: + # if the last message has just a content string, convert it to a list + # in the proper format + if isinstance(self.messages[-1]["content"], str): + self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}] + # if this message has just a content string, convert it to a list + # in the proper format + if isinstance(message["content"], str): + message["content"] = [{"text": message["content"]}] + # append the content of this message to the last message + self.messages[-1]["content"].extend(message["content"]) + else: + self.messages.append(message) + else: + self.messages.append(message) + except Exception as e: + logger.error(f"Error adding message: {e}") + + def _restructure_from_bedrock_messages(self): + """Restructure messages in AWS Bedrock format by handling system + messages, merging consecutive messages with the same role, and ensuring + proper content formatting. + + """ + # Handle system message if present at the beginning + if self.messages and self.messages[0]["role"] == "system": + if len(self.messages) == 1: + self.messages[0]["role"] = "user" + else: + system_content = self.messages.pop(0)["content"] + if isinstance(system_content, str): + system_content = [{"text": system_content}] + + if self.system: + if isinstance(self.system, str): + self.system = [{"text": self.system}] + self.system.extend(system_content) + else: + self.system = system_content + + # Ensure content is properly formatted + for msg in self.messages: + if isinstance(msg["content"], str): + msg["content"] = [{"text": msg["content"]}] + elif not msg["content"]: + msg["content"] = [{"text": "(empty)"}] + elif isinstance(msg["content"], list): + for idx, item in enumerate(msg["content"]): + if isinstance(item, dict) and "text" in item and item["text"] == "": + item["text"] = "(empty)" + elif isinstance(item, str) and item == "": + msg["content"][idx] = {"text": "(empty)"} + + # Merge consecutive messages with the same role + merged_messages = [] + for msg in self.messages: + if merged_messages and merged_messages[-1]["role"] == msg["role"]: + merged_messages[-1]["content"].extend(msg["content"]) + else: + merged_messages.append(msg) + + self.messages.clear() + self.messages.extend(merged_messages) + + def _restructure_from_openai_messages(self): + # first, map across self._messages calling self.from_standard_message(m) to modify messages in place + try: + self._messages[:] = [self.from_standard_message(m) for m in self._messages] + except Exception as e: + logger.error(f"Error mapping messages: {e}") + + # See if we should pull the system message out of our context.messages list. (For + # compatibility with Open AI messages format.) + if self.messages and self.messages[0]["role"] == "system": + self.system = self.messages[0]["content"] + self.messages.pop(0) + + # Merge consecutive messages with the same role. + i = 0 + while i < len(self.messages) - 1: + current_message = self.messages[i] + next_message = self.messages[i + 1] + if current_message["role"] == next_message["role"]: + # Convert content to list of dictionaries if it's a string + if isinstance(current_message["content"], str): + current_message["content"] = [ + {"type": "text", "text": current_message["content"]} + ] + if isinstance(next_message["content"], str): + next_message["content"] = [{"type": "text", "text": next_message["content"]}] + # Concatenate the content + current_message["content"].extend(next_message["content"]) + # Remove the next message from the list + self.messages.pop(i + 1) + else: + i += 1 + + # Avoid empty content in messages + for message in self.messages: + if isinstance(message["content"], str) and message["content"] == "": + message["content"] = "(empty)" + elif isinstance(message["content"], list) and len(message["content"]) == 0: + message["content"] = [{"type": "text", "text": "(empty)"}] + + def get_messages_for_persistent_storage(self): + messages = super().get_messages_for_persistent_storage() + if self.system: + messages.insert(0, {"role": "system", "content": self.system}) + return messages + + def get_messages_for_logging(self) -> str: + msgs = [] + for message in self.messages: + msg = copy.deepcopy(message) + if "content" in msg: + if isinstance(msg["content"], list): + for item in msg["content"]: + if item.get("image"): + item["source"]["bytes"] = "..." + msgs.append(msg) + return json.dumps(msgs) + + +class AWSBedrockUserContextAggregator(LLMUserContextAggregator): + pass + + +class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator): + async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): + # Format tool use according to AWS Bedrock API + self._context.add_message( + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": frame.tool_call_id, + "name": frame.function_name, + "input": frame.arguments if frame.arguments else {}, + } + } + ], + } + ) + self._context.add_message( + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": frame.tool_call_id, + "content": [{"text": "IN_PROGRESS"}], + } + } + ], + } + ) + + async def handle_function_call_result(self, frame: FunctionCallResultFrame): + if frame.result: + result = json.dumps(frame.result) + await self._update_function_call_result(frame.function_name, frame.tool_call_id, result) + else: + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "COMPLETED" + ) + + async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "CANCELLED" + ) + + async def _update_function_call_result( + self, function_name: str, tool_call_id: str, result: Any + ): + for message in self._context.messages: + if message["role"] == "user": + for content in message["content"]: + if ( + isinstance(content, dict) + and content.get("toolResult") + and content["toolResult"]["toolUseId"] == tool_call_id + ): + content["toolResult"]["content"] = [{"text": result}] + + async def handle_user_image_frame(self, frame: UserImageRawFrame): + await self._update_function_call_result( + frame.request.function_name, frame.request.tool_call_id, "COMPLETED" + ) + self._context.add_image_frame_message( + format=frame.format, + size=frame.size, + image=frame.image, + text=frame.request.context, + ) + + +class AWSBedrockLLMService(LLMService): + """This class implements inference with AWS Bedrock models including Amazon + Nova and Anthropic Claude. + + Requires AWS credentials to be configured in the environment or through + boto3 configuration. + + """ + + # Overriding the default adapter to use the Anthropic one. + adapter_class = AWSBedrockLLMAdapter + + class InputParams(BaseModel): + max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1) + temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0) + top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0) + stop_sequences: Optional[List[str]] = Field(default_factory=lambda: []) + latency: Optional[str] = Field(default_factory=lambda: "standard") + additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict) + + def __init__( + self, + *, + aws_access_key: Optional[str] = None, + aws_secret_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region: str = "us-east-1", + model: str, + params: InputParams = InputParams(), + client_config: Optional[Config] = None, + **kwargs, + ): + super().__init__(**kwargs) + + # Initialize the AWS Bedrock client + if not client_config: + client_config = Config( + connect_timeout=300, # 5 minutes + read_timeout=300, # 5 minutes + retries={"max_attempts": 3}, + ) + session = boto3.Session( + aws_access_key_id=aws_access_key, + aws_secret_access_key=aws_secret_key, + aws_session_token=aws_session_token, + region_name=aws_region, + ) + self._client = session.client(service_name="bedrock-runtime", config=client_config) + + self.set_model_name(model) + self._settings = { + "max_tokens": params.max_tokens, + "temperature": params.temperature, + "top_p": params.top_p, + "latency": params.latency, + "additional_model_request_fields": params.additional_model_request_fields + if isinstance(params.additional_model_request_fields, dict) + else {}, + } + + logger.info(f"Using AWS Bedrock model: {model}") + + def can_generate_metrics(self) -> bool: + return True + + def create_context_aggregator( + self, + context: OpenAILLMContext, + *, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), + ) -> AWSBedrockContextAggregatorPair: + """Create an instance of AWSBedrockContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): User + aggregator parameters. + + Returns: + AWSBedrockContextAggregatorPair: A pair of context aggregators, one + for the user and one for the assistant, encapsulated in an + AWSBedrockContextAggregatorPair. + """ + context.set_llm_adapter(self.get_llm_adapter()) + + if isinstance(context, OpenAILLMContext): + context = AWSBedrockLLMContext.from_openai_context(context) + + user = AWSBedrockUserContextAggregator(context, params=user_params) + assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params) + return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant) + + async def _process_context(self, context: AWSBedrockLLMContext): + # Usage tracking + prompt_tokens = 0 + completion_tokens = 0 + completion_tokens_estimate = 0 + cache_read_input_tokens = 0 + cache_creation_input_tokens = 0 + use_completion_tokens_estimate = False + + try: + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + + await self.start_ttfb_metrics() + + # Set up inference config + inference_config = { + "maxTokens": self._settings["max_tokens"], + "temperature": self._settings["temperature"], + "topP": self._settings["top_p"], + } + + # Prepare request parameters + request_params = { + "modelId": self.model_name, + "messages": context.messages, + "inferenceConfig": inference_config, + "additionalModelRequestFields": self._settings["additional_model_request_fields"], + } + + # Add system message + request_params["system"] = context.system + + # Add tools if present + if context.tools: + tool_config = {"tools": context.tools} + + # Add tool_choice if specified + if context.tool_choice: + if context.tool_choice == "auto": + tool_config["toolChoice"] = {"auto": {}} + elif context.tool_choice == "none": + # Skip adding toolChoice for "none" + pass + elif ( + isinstance(context.tool_choice, dict) and "function" in context.tool_choice + ): + tool_config["toolChoice"] = { + "tool": {"name": context.tool_choice["function"]["name"]} + } + + request_params["toolConfig"] = tool_config + + # Add performance config if latency is specified + if self._settings["latency"] in ["standard", "optimized"]: + request_params["performanceConfig"] = {"latency": self._settings["latency"]} + + logger.debug(f"Calling AWS Bedrock model with: {request_params}") + + # Call AWS Bedrock with streaming + response = self._client.converse_stream(**request_params) + + await self.stop_ttfb_metrics() + + # Process the streaming response + tool_use_block = None + json_accumulator = "" + + for event in response["stream"]: + # Handle text content + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + await self.push_frame(LLMTextFrame(delta["text"])) + completion_tokens_estimate += self._estimate_tokens(delta["text"]) + elif "toolUse" in delta and "input" in delta["toolUse"]: + # Handle partial JSON for tool use + json_accumulator += delta["toolUse"]["input"] + completion_tokens_estimate += self._estimate_tokens( + delta["toolUse"]["input"] + ) + + # Handle tool use start + elif "contentBlockStart" in event: + content_block_start = event["contentBlockStart"]["start"] + if "toolUse" in content_block_start: + tool_use_block = { + "id": content_block_start["toolUse"].get("toolUseId", ""), + "name": content_block_start["toolUse"].get("name", ""), + } + json_accumulator = "" + + # Handle message completion with tool use + elif "messageStop" in event and "stopReason" in event["messageStop"]: + if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block: + try: + arguments = json.loads(json_accumulator) if json_accumulator else {} + await self.call_function( + context=context, + tool_call_id=tool_use_block["id"], + function_name=tool_use_block["name"], + arguments=arguments, + ) + except json.JSONDecodeError: + logger.error(f"Failed to parse tool arguments: {json_accumulator}") + + # Handle usage metrics if available + if "metadata" in event and "usage" in event["metadata"]: + usage = event["metadata"]["usage"] + prompt_tokens += usage.get("inputTokens", 0) + completion_tokens += usage.get("outputTokens", 0) + cache_read_input_tokens += usage.get("cacheReadInputTokens", 0) + cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0) + + except asyncio.CancelledError: + # If we're interrupted, we won't get a complete usage report. So set our flag to use the + # token estimate. The reraise the exception so all the processors running in this task + # also get cancelled. + use_completion_tokens_estimate = True + raise + except httpx.TimeoutException: + await self._call_event_handler("on_completion_timeout") + except Exception as e: + logger.exception(f"{self} exception: {e}") + finally: + await self.stop_processing_metrics() + await self.push_frame(LLMFullResponseEndFrame()) + comp_tokens = ( + completion_tokens + if not use_completion_tokens_estimate + else completion_tokens_estimate + ) + await self._report_usage_metrics( + prompt_tokens=prompt_tokens, + completion_tokens=comp_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_creation_input_tokens=cache_creation_input_tokens, + ) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + if isinstance(frame, OpenAILLMContextFrame): + context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context) + elif isinstance(frame, LLMMessagesFrame): + context = AWSBedrockLLMContext.from_messages(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + # This is only useful in very simple pipelines because it creates + # a new context. Generally we want a context manager to catch + # UserImageRawFrames coming through the pipeline and add them + # to the context. + context = AWSBedrockLLMContext.from_image_frame(frame) + elif isinstance(frame, LLMUpdateSettingsFrame): + await self._update_settings(frame.settings) + else: + await self.push_frame(frame, direction) + + if context: + await self._process_context(context) + + def _estimate_tokens(self, text: str) -> int: + return int(len(re.split(r"[^\w]+", text)) * 1.3) + + async def _report_usage_metrics( + self, + prompt_tokens: int, + completion_tokens: int, + cache_read_input_tokens: int, + cache_creation_input_tokens: int, + ): + if prompt_tokens or completion_tokens: + tokens = LLMTokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_creation_input_tokens=cache_creation_input_tokens, + ) + await self.start_llm_usage_metrics(tokens) diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py new file mode 100644 index 000000000..a02625f81 --- /dev/null +++ b/src/pipecat/services/aws/stt.py @@ -0,0 +1,329 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import json +import os +import random +import string +from typing import AsyncGenerator, Optional + +from loguru import logger + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + InterimTranscriptionFrame, + StartFrame, + TranscriptionFrame, +) +from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url +from pipecat.services.stt_service import STTService +from pipecat.transcriptions.language import Language +from pipecat.utils.time import time_now_iso8601 + +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.") + raise Exception(f"Missing module: {e}") + + +class AWSTranscribeSTTService(STTService): + def __init__( + self, + *, + api_key: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_session_token: Optional[str] = None, + region: Optional[str] = "us-east-1", + sample_rate: int = 16000, + language: Language = Language.EN, + **kwargs, + ): + super().__init__(**kwargs) + + self._settings = { + "sample_rate": sample_rate, + "language": language, + "media_encoding": "linear16", # AWS expects raw PCM + "number_of_channels": 1, + "show_speaker_label": False, + "enable_channel_identification": False, + } + + # Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz + if sample_rate not in [8000, 16000]: + logger.warning( + f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz." + ) + self._settings["sample_rate"] = 16000 + + self._credentials = { + "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"), + "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + "region": region or os.getenv("AWS_REGION", "us-east-1"), + } + + self._ws_client = None + self._connection_lock = asyncio.Lock() + self._connecting = False + self._receive_task = None + + def get_service_encoding(self, encoding: str) -> str: + """Convert internal encoding format to AWS Transcribe format.""" + encoding_map = { + "linear16": "pcm", # AWS expects "pcm" for 16-bit linear PCM + } + return encoding_map.get(encoding, encoding) + + async def start(self, frame: StartFrame): + """Initialize the connection when the service starts.""" + await super().start(frame) + logger.info("Starting AWS Transcribe service...") + retry_count = 0 + max_retries = 3 + + while retry_count < max_retries: + try: + await self._connect() + if self._ws_client and self._ws_client.open: + logger.info("Successfully established WebSocket connection") + return + logger.warning("WebSocket connection not established after connect") + except Exception as e: + logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}") + retry_count += 1 + if retry_count < max_retries: + await asyncio.sleep(1) # Wait before retrying + + raise RuntimeError("Failed to establish WebSocket connection after multiple attempts") + + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._disconnect() + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Process audio data and send to AWS Transcribe""" + try: + # Ensure WebSocket is connected + if not self._ws_client or not self._ws_client.open: + logger.debug("WebSocket not connected, attempting to reconnect...") + try: + await self._connect() + except Exception as e: + logger.error(f"Failed to reconnect: {e}") + yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False) + return + + # Format the audio data according to AWS event stream format + event_message = build_event_message(audio) + + # Send the formatted event message + try: + await self._ws_client.send(event_message) + # Start metrics after first chunk sent + await self.start_processing_metrics() + await self.start_ttfb_metrics() + except websockets.exceptions.ConnectionClosed as e: + logger.warning(f"Connection closed while sending: {e}") + await self._disconnect() + # Don't yield error here - we'll retry on next frame + except Exception as e: + logger.error(f"Error sending audio: {e}") + yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False) + await self._disconnect() + + except Exception as e: + logger.error(f"Error in run_stt: {e}") + yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False) + await self._disconnect() + + async def _connect(self): + """Connect to AWS Transcribe with connection state management.""" + if self._ws_client and self._ws_client.open and self._receive_task: + logger.debug(f"{self} Already connected") + return + + async with self._connection_lock: + if self._connecting: + logger.debug(f"{self} Connection already in progress") + return + + try: + self._connecting = True + logger.debug(f"{self} Starting connection process...") + + if self._ws_client: + await self._disconnect() + + language_code = self.language_to_service_language( + Language(self._settings["language"]) + ) + if not language_code: + raise ValueError(f"Unsupported language: {self._settings['language']}") + + # Generate random websocket key + websocket_key = "".join( + random.choices( + string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20 + ) + ) + + # Add required headers + extra_headers = { + "Origin": "https://localhost", + "Sec-WebSocket-Key": websocket_key, + "Sec-WebSocket-Version": "13", + "Connection": "keep-alive", + } + + # Get presigned URL + presigned_url = get_presigned_url( + region=self._credentials["region"], + credentials={ + "access_key": self._credentials["aws_access_key_id"], + "secret_key": self._credentials["aws_secret_access_key"], + "session_token": self._credentials["aws_session_token"], + }, + language_code=language_code, + media_encoding=self.get_service_encoding( + self._settings["media_encoding"] + ), # Convert to AWS format + sample_rate=self._settings["sample_rate"], + number_of_channels=self._settings["number_of_channels"], + enable_partial_results_stabilization=True, + partial_results_stability="high", + show_speaker_label=self._settings["show_speaker_label"], + enable_channel_identification=self._settings["enable_channel_identification"], + ) + + logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...") + + # Connect with the required headers and settings + self._ws_client = await websockets.connect( + presigned_url, + extra_headers=extra_headers, + subprotocols=["mqtt"], + ping_interval=None, + ping_timeout=None, + compression=None, + ) + + logger.debug(f"{self} WebSocket connected, starting receive task...") + + # Start receive task + self._receive_task = self.create_task(self._receive_loop()) + + logger.info(f"{self} Successfully connected to AWS Transcribe") + + except Exception as e: + logger.error(f"{self} Failed to connect to AWS Transcribe: {e}") + await self._disconnect() + raise + + finally: + self._connecting = False + + async def _disconnect(self): + """Disconnect from AWS Transcribe.""" + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + try: + if self._ws_client and self._ws_client.open: + # Send end-stream message + end_stream = {"message-type": "event", "event": "end"} + await self._ws_client.send(json.dumps(end_stream)) + await self._ws_client.close() + except Exception as e: + logger.warning(f"{self} Error closing WebSocket connection: {e}") + finally: + self._ws_client = None + + def language_to_service_language(self, language: Language) -> str | None: + """Convert internal language enum to AWS Transcribe language code.""" + language_map = { + Language.EN: "en-US", + Language.ES: "es-US", + Language.FR: "fr-FR", + Language.DE: "de-DE", + Language.IT: "it-IT", + Language.PT: "pt-BR", + Language.JA: "ja-JP", + Language.KO: "ko-KR", + Language.ZH: "zh-CN", + } + return language_map.get(language) + + async def _receive_loop(self): + """Background task to receive and process messages from AWS Transcribe.""" + while True: + if not self._ws_client or not self._ws_client.open: + logger.warning(f"{self} WebSocket closed in receive loop") + break + + try: + response = await self._ws_client.recv() + headers, payload = decode_event(response) + + if headers.get(":message-type") == "event": + # Process transcription results + results = payload.get("Transcript", {}).get("Results", []) + if results: + result = results[0] + alternatives = result.get("Alternatives", []) + if alternatives: + transcript = alternatives[0].get("Transcript", "") + is_final = not result.get("IsPartial", True) + + if transcript: + await self.stop_ttfb_metrics() + if is_final: + await self.push_frame( + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], + ) + ) + await self.stop_processing_metrics() + else: + await self.push_frame( + InterimTranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], + ) + ) + elif headers.get(":message-type") == "exception": + error_msg = payload.get("Message", "Unknown error") + logger.error(f"{self} Exception from AWS: {error_msg}") + await self.push_frame( + ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False) + ) + else: + logger.debug(f"{self} Other message type received: {headers}") + logger.debug(f"{self} Payload: {payload}") + except websockets.exceptions.ConnectionClosed as e: + logger.error( + f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}" + ) + break + except Exception as e: + logger.error(f"{self} Unexpected error in receive loop: {e}") + break diff --git a/src/pipecat/services/aws/tts.py b/src/pipecat/services/aws/tts.py index db6e168ab..40d746514 100644 --- a/src/pipecat/services/aws/tts.py +++ b/src/pipecat/services/aws/tts.py @@ -5,6 +5,7 @@ # import asyncio +import os from typing import AsyncGenerator, Optional from loguru import logger @@ -26,9 +27,7 @@ try: from botocore.exceptions import BotoCoreError, ClientError except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." - ) + logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.") raise Exception(f"Missing module: {e}") @@ -108,7 +107,7 @@ def language_to_aws_language(language: Language) -> Optional[str]: return language_map.get(language) -class PollyTTSService(TTSService): +class AWSPollyTTSService(TTSService): class InputParams(BaseModel): engine: Optional[str] = None language: Optional[Language] = Language.EN @@ -151,6 +150,24 @@ class PollyTTSService(TTSService): self.set_voice(voice_id) + # Get credentials from environment variables if not provided + self._credentials = { + "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"), + "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + "region": region or os.getenv("AWS_REGION", "us-east-1"), + } + + # Validate that we have the required credentials + if ( + not self._credentials["aws_access_key_id"] + or not self._credentials["aws_secret_access_key"] + ): + raise ValueError( + "AWS credentials not found. Please provide them either through constructor parameters " + "or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables." + ) + def can_generate_metrics(self) -> bool: return True @@ -165,18 +182,17 @@ class PollyTTSService(TTSService): prosody_attrs = [] # Prosody tags are only supported for standard and neural engines - if self._settings["engine"] != "generative": - if self._settings["rate"]: - prosody_attrs.append(f"rate='{self._settings['rate']}'") + if self._settings["engine"] == "standard": if self._settings["pitch"]: prosody_attrs.append(f"pitch='{self._settings['pitch']}'") - if self._settings["volume"]: - prosody_attrs.append(f"volume='{self._settings['volume']}'") - if prosody_attrs: - ssml += f"" - else: - logger.warning("Prosody tags are not supported for generative engine. Ignoring.") + if self._settings["rate"]: + prosody_attrs.append(f"rate='{self._settings['rate']}'") + if self._settings["volume"]: + prosody_attrs.append(f"volume='{self._settings['volume']}'") + + if prosody_attrs: + ssml += f"" ssml += text @@ -187,6 +203,8 @@ class PollyTTSService(TTSService): ssml += "" + logger.trace(f"{self} SSML: {ssml}") + return ssml async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: @@ -248,3 +266,17 @@ class PollyTTSService(TTSService): finally: yield TTSStoppedFrame() + + +class PollyTTSService(AWSPollyTTSService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.", + DeprecationWarning, + ) diff --git a/src/pipecat/services/aws/utils.py b/src/pipecat/services/aws/utils.py new file mode 100644 index 000000000..db69456e9 --- /dev/null +++ b/src/pipecat/services/aws/utils.py @@ -0,0 +1,261 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import binascii +import datetime +import hashlib +import hmac +import json +import struct +import urllib.parse +from typing import Dict, Optional + + +def get_presigned_url( + *, + region: str, + credentials: Dict[str, Optional[str]], + language_code: str, + media_encoding: str = "pcm", + sample_rate: int = 16000, + number_of_channels: int = 1, + enable_partial_results_stabilization: bool = True, + partial_results_stability: str = "high", + vocabulary_name: Optional[str] = None, + vocabulary_filter_name: Optional[str] = None, + show_speaker_label: bool = False, + enable_channel_identification: bool = False, +) -> str: + """Create a presigned URL for AWS Transcribe streaming.""" + access_key = credentials.get("access_key") + secret_key = credentials.get("secret_key") + session_token = credentials.get("session_token") + + if not access_key or not secret_key: + raise ValueError("AWS credentials are required") + + # Initialize the URL generator + url_generator = AWSTranscribePresignedURL( + access_key=access_key, secret_key=secret_key, session_token=session_token, region=region + ) + + # Get the presigned URL + return url_generator.get_request_url( + sample_rate=sample_rate, + language_code=language_code, + media_encoding=media_encoding, + vocabulary_name=vocabulary_name, + vocabulary_filter_name=vocabulary_filter_name, + show_speaker_label=show_speaker_label, + enable_channel_identification=enable_channel_identification, + number_of_channels=number_of_channels, + enable_partial_results_stabilization=enable_partial_results_stabilization, + partial_results_stability=partial_results_stability, + ) + + +class AWSTranscribePresignedURL: + def __init__( + self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1" + ): + self.access_key = access_key + self.secret_key = secret_key + self.session_token = session_token + self.method = "GET" + self.service = "transcribe" + self.region = region + self.endpoint = "" + self.host = "" + self.amz_date = "" + self.datestamp = "" + self.canonical_uri = "/stream-transcription-websocket" + self.canonical_headers = "" + self.signed_headers = "host" + self.algorithm = "AWS4-HMAC-SHA256" + self.credential_scope = "" + self.canonical_querystring = "" + self.payload_hash = "" + self.canonical_request = "" + self.string_to_sign = "" + self.signature = "" + self.request_url = "" + + def get_request_url( + self, + sample_rate: int, + language_code: str = "", + media_encoding: str = "pcm", + vocabulary_name: str = "", + vocabulary_filter_name: str = "", + show_speaker_label: bool = False, + enable_channel_identification: bool = False, + number_of_channels: int = 1, + enable_partial_results_stabilization: bool = False, + partial_results_stability: str = "", + ) -> str: + self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443" + self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443" + + now = datetime.datetime.utcnow() + self.amz_date = now.strftime("%Y%m%dT%H%M%SZ") + self.datestamp = now.strftime("%Y%m%d") + self.canonical_headers = f"host:{self.host}\n" + self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request" + + # Create canonical querystring + self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm + self.canonical_querystring += ( + "&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope + ) + self.canonical_querystring += "&X-Amz-Date=" + self.amz_date + self.canonical_querystring += "&X-Amz-Expires=300" + if self.session_token: + self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote( + self.session_token, safe="" + ) + self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers + + if enable_channel_identification: + self.canonical_querystring += "&enable-channel-identification=true" + if enable_partial_results_stabilization: + self.canonical_querystring += "&enable-partial-results-stabilization=true" + if language_code: + self.canonical_querystring += "&language-code=" + language_code + if media_encoding: + self.canonical_querystring += "&media-encoding=" + media_encoding + if number_of_channels > 1: + self.canonical_querystring += "&number-of-channels=" + str(number_of_channels) + if partial_results_stability: + self.canonical_querystring += "&partial-results-stability=" + partial_results_stability + if sample_rate: + self.canonical_querystring += "&sample-rate=" + str(sample_rate) + if show_speaker_label: + self.canonical_querystring += "&show-speaker-label=true" + if vocabulary_filter_name: + self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name + if vocabulary_name: + self.canonical_querystring += "&vocabulary-name=" + vocabulary_name + + # Create payload hash + self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest() + + # Create canonical request + self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}" + + # Create string to sign + credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request" + string_to_sign = ( + f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n" + + hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest() + ) + + # Calculate signature + k_date = hmac.new( + f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256 + ).digest() + k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest() + k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest() + k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest() + self.signature = hmac.new( + k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + # Add signature to query string + self.canonical_querystring += "&X-Amz-Signature=" + self.signature + + # Create request URL + self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring + return self.request_url + + +def get_headers(header_name: str, header_value: str) -> bytearray: + """Build a header following AWS event stream format.""" + name = header_name.encode("utf-8") + name_byte_length = bytes([len(name)]) + value_type = bytes([7]) # 7 represents a string + value = header_value.encode("utf-8") + value_byte_length = struct.pack(">H", len(value)) + + # Construct the header + header_list = bytearray() + header_list.extend(name_byte_length) + header_list.extend(name) + header_list.extend(value_type) + header_list.extend(value_byte_length) + header_list.extend(value) + return header_list + + +def build_event_message(payload: bytes) -> bytes: + """ + Build an event message for AWS Transcribe streaming. + Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py + """ + # Build headers + content_type_header = get_headers(":content-type", "application/octet-stream") + event_type_header = get_headers(":event-type", "AudioEvent") + message_type_header = get_headers(":message-type", "event") + + headers = bytearray() + headers.extend(content_type_header) + headers.extend(event_type_header) + headers.extend(message_type_header) + + # Calculate total byte length and headers byte length + # 16 accounts for 8 byte prelude, 2x 4 byte CRCs + total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16) + headers_byte_length = struct.pack(">I", len(headers)) + + # Build the prelude + prelude = bytearray([0] * 8) + prelude[:4] = total_byte_length + prelude[4:] = headers_byte_length + + # Calculate checksum for prelude + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + + # Construct the message + message_as_list = bytearray() + message_as_list.extend(prelude) + message_as_list.extend(prelude_crc) + message_as_list.extend(headers) + message_as_list.extend(payload) + + # Calculate checksum for message + message = bytes(message_as_list) + message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF) + + # Add message checksum + message_as_list.extend(message_crc) + + return bytes(message_as_list) + + +def decode_event(message): + # Extract the prelude, headers, payload and CRC + prelude = message[:8] + total_length, headers_length = struct.unpack(">II", prelude) + prelude_crc = struct.unpack(">I", message[8:12])[0] + headers = message[12 : 12 + headers_length] + payload = message[12 + headers_length : -4] + message_crc = struct.unpack(">I", message[-4:])[0] + + # Check the CRCs + assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed" + assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed" + + # Parse the headers + headers_dict = {} + while headers: + name_len = headers[0] + name = headers[1 : 1 + name_len].decode("utf-8") + value_type = headers[1 + name_len] + value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0] + value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8") + headers_dict[name] = value + headers = headers[4 + name_len + value_len :] + + return headers_dict, json.loads(payload) diff --git a/src/pipecat/services/aws_nova_sonic/__init__.py b/src/pipecat/services/aws_nova_sonic/__init__.py new file mode 100644 index 000000000..e14c44f8a --- /dev/null +++ b/src/pipecat/services/aws_nova_sonic/__init__.py @@ -0,0 +1 @@ +from .aws import AWSNovaSonicLLMService diff --git a/src/pipecat/services/aws_nova_sonic/aws.py b/src/pipecat/services/aws_nova_sonic/aws.py new file mode 100644 index 000000000..410481065 --- /dev/null +++ b/src/pipecat/services/aws_nova_sonic/aws.py @@ -0,0 +1,997 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import base64 +import json +import time +import uuid +import wave +from dataclasses import dataclass +from enum import Enum +from importlib.resources import files +from typing import Any, List, Optional + +from loguru import logger +from pydantic import BaseModel, Field + +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter +from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + CancelFrame, + EndFrame, + Frame, + InputAudioRawFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMTextFrame, + StartFrame, + TranscriptionFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, + TTSTextFrame, +) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, + LLMUserAggregatorParams, +) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.aws_nova_sonic.context import ( + AWSNovaSonicAssistantContextAggregator, + AWSNovaSonicContextAggregatorPair, + AWSNovaSonicLLMContext, + AWSNovaSonicUserContextAggregator, + Role, +) +from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame +from pipecat.services.llm_service import LLMService +from pipecat.utils.time import time_now_iso8601 + +try: + from aws_sdk_bedrock_runtime.client import ( + BedrockRuntimeClient, + InvokeModelWithBidirectionalStreamOperationInput, + ) + from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme + from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInput, + InvokeModelWithBidirectionalStreamInputChunk, + InvokeModelWithBidirectionalStreamOperationOutput, + InvokeModelWithBidirectionalStreamOutput, + ) + from smithy_aws_core.credentials_resolvers.static import StaticCredentialsResolver + from smithy_aws_core.identity import AWSCredentialsIdentity + from smithy_core.aio.eventstream import DuplexEventStream +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use AWS services, you need to `pip install pipecat-ai[aws-nova-sonic]`." + ) + raise Exception(f"Missing module: {e}") + + +class AWSNovaSonicUnhandledFunctionException(Exception): + pass + + +class ContentType(Enum): + AUDIO = "AUDIO" + TEXT = "TEXT" + TOOL = "TOOL" + + +class TextStage(Enum): + FINAL = "FINAL" # what has been said + SPECULATIVE = "SPECULATIVE" # what's planned to be said + + +@dataclass +class CurrentContent: + type: ContentType + role: Role + text_stage: TextStage # None if not text + text_content: str # starts as None, then fills in if text + + def __str__(self): + return ( + f"CurrentContent(\n" + f" type={self.type.name},\n" + f" role={self.role.name},\n" + f" text_stage={self.text_stage.name if self.text_stage else 'None'}\n" + f")" + ) + + +class Params(BaseModel): + # Audio input + input_sample_rate: Optional[int] = Field(default=16000) + input_sample_size: Optional[int] = Field(default=16) + input_channel_count: Optional[int] = Field(default=1) + + # Audio output + output_sample_rate: Optional[int] = Field(default=24000) + output_sample_size: Optional[int] = Field(default=16) + output_channel_count: Optional[int] = Field(default=1) + + # Inference + max_tokens: Optional[int] = Field(default=1024) + top_p: Optional[float] = Field(default=0.9) + temperature: Optional[float] = Field(default=0.7) + + +class AWSNovaSonicLLMService(LLMService): + # Override the default adapter to use the AWSNovaSonicLLMAdapter one + adapter_class = AWSNovaSonicLLMAdapter + + def __init__( + self, + *, + secret_access_key: str, + access_key_id: str, + region: str, + model: str = "amazon.nova-sonic-v1:0", + voice_id: str = "matthew", # matthew, tiffany, amy + params: Params = Params(), + system_instruction: Optional[str] = None, + tools: Optional[ToolsSchema] = None, + send_transcription_frames: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._secret_access_key = secret_access_key + self._access_key_id = access_key_id + self._region = region + self._model = model + self._client: Optional[BedrockRuntimeClient] = None + self._voice_id = voice_id + self._params = params + self._system_instruction = system_instruction + self._tools = tools + self._send_transcription_frames = send_transcription_frames + self._context: Optional[AWSNovaSonicLLMContext] = None + self._stream: Optional[ + DuplexEventStream[ + InvokeModelWithBidirectionalStreamInput, + InvokeModelWithBidirectionalStreamOutput, + InvokeModelWithBidirectionalStreamOperationOutput, + ] + ] = None + self._receive_task: Optional[asyncio.Task] = None + self._prompt_name: Optional[str] = None + self._input_audio_content_name: Optional[str] = None + self._content_being_received: Optional[CurrentContent] = None + self._assistant_is_responding = False + self._ready_to_send_context = False + self._handling_bot_stopped_speaking = False + self._triggering_assistant_response = False + self._assistant_response_trigger_audio: Optional[bytes] = ( + None # Not cleared on _disconnect() + ) + self._disconnecting = False + self._connected_time: Optional[float] = None + self._wants_connection = False + + # + # standard AIService frame handling + # + + async def start(self, frame: StartFrame): + await super().start(frame) + self._wants_connection = True + await self._start_connecting() + + async def stop(self, frame: EndFrame): + await super().stop(frame) + self._wants_connection = False + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + self._wants_connection = False + await self._disconnect() + + # + # conversation resetting + # + + async def reset_conversation(self): + logger.debug("Resetting conversation") + await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False) + + # Carry over previous context through disconnect + context = self._context + await self._disconnect() + self._context = context + + await self._start_connecting() + + # + # frame processing + # + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, OpenAILLMContextFrame): + await self._handle_context(frame.context) + elif isinstance(frame, InputAudioRawFrame): + await self._handle_input_audio_frame(frame) + elif isinstance(frame, BotStoppedSpeakingFrame): + await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=True) + elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame): + await self._handle_function_call_result(frame) + + await self.push_frame(frame, direction) + + async def _handle_context(self, context: OpenAILLMContext): + if not self._context: + # We got our initial context - try to finish connecting + self._context = AWSNovaSonicLLMContext.upgrade_to_nova_sonic( + context, self._system_instruction + ) + await self._finish_connecting_if_context_available() + + async def _handle_input_audio_frame(self, frame: InputAudioRawFrame): + # Wait until we're done sending the assistant response trigger audio before sending audio + # from the user's mic + if self._triggering_assistant_response: + return + + await self._send_user_audio_event(frame.audio) + + async def _handle_bot_stopped_speaking(self, delay_to_catch_trailing_assistant_text: bool): + # Protect against back-to-back BotStoppedSpeaking calls, which I've observed + if self._handling_bot_stopped_speaking: + return + self._handling_bot_stopped_speaking = True + + async def finalize_assistant_response(): + if self._assistant_is_responding: + # Consider the assistant finished with their response (possibly after a short delay, + # to allow for any trailing FINAL assistant text block to come in that need to make + # it into context). + # + # TODO: ideally we could base this solely on the LLM output events, but I couldn't + # figure out a reliable way to determine when we've gotten our last FINAL text block + # after the LLM is done talking. + # + # First I looked at stopReason, but it doesn't seem like the last FINAL text block + # is reliably marked END_TURN (sometimes the *first* one is, but not the last... + # bug?) + # + # Then I considered schemes where we tally or match up SPECULATIVE text blocks with + # FINAL text blocks to know how many or which FINAL blocks to expect, but user + # interruptions throw a wrench in these schemes: depending on the exact timing of + # the interruption, we should or shouldn't expect some FINAL blocks. + if delay_to_catch_trailing_assistant_text: + # This delay length is a balancing act between "catching" trailing assistant + # text that is quite delayed but not waiting so long that user text comes in + # first and results in a bit of context message order scrambling. + await asyncio.sleep(1.25) + self._assistant_is_responding = False + await self._report_assistant_response_ended() + + self._handling_bot_stopped_speaking = False + + # Finalize the assistant response, either now or after a delay + if delay_to_catch_trailing_assistant_text: + self.create_task(finalize_assistant_response()) + else: + await finalize_assistant_response() + + async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame): + result = frame.result_frame + await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result) + + # + # LLM communication: lifecycle + # + + async def _start_connecting(self): + try: + logger.info("Connecting...") + + if self._client: + # Here we assume that if we have a client we are connected or connecting + return + + # Set IDs for the connection + self._prompt_name = str(uuid.uuid4()) + self._input_audio_content_name = str(uuid.uuid4()) + + # Create the client + self._client = self._create_client() + + # Start the bidirectional stream + self._stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self._model) + ) + + # Send session start event + await self._send_session_start_event() + + # Finish connecting + self._ready_to_send_context = True + await self._finish_connecting_if_context_available() + except Exception as e: + logger.error(f"{self} initialization error: {e}") + self._disconnect() + + async def _finish_connecting_if_context_available(self): + # We can only finish connecting once we've gotten our initial context and we're ready to + # send it + if not (self._context and self._ready_to_send_context): + return + + logger.info("Finishing connecting (setting up session)...") + + # Read context + history = self._context.get_messages_for_initializing_history() + + # Send prompt start event, specifying tools. + # Tools from context take priority over self._tools. + tools = ( + self._context.tools + if self._context.tools + else self.get_llm_adapter().from_standard_tools(self._tools) + ) + logger.debug(f"Using tools: {tools}") + await self._send_prompt_start_event(tools) + + # Send system instruction. + # Instruction from context takes priority over self._system_instruction. + # (NOTE: this prioritizing occurred automatically behind the scenes: the context was + # initialized with self._system_instruction and then updated itself from its messages when + # get_messages_for_initializing_history() was called). + logger.debug(f"Using system instruction: {history.system_instruction}") + if history.system_instruction: + await self._send_text_event(text=history.system_instruction, role=Role.SYSTEM) + + # Send conversation history + for message in history.messages: + await self._send_text_event(text=message.text, role=message.role) + + # Start audio input + await self._send_audio_input_start_event() + + # Start receiving events + self._receive_task = self.create_task(self._receive_task_handler()) + + # Record finished connecting time (must be done before sending assistant response trigger) + self._connected_time = time.time() + + logger.info("Finished connecting") + + # If we need to, send assistant response trigger (depends on self._connected_time) + if self._triggering_assistant_response: + await self._send_assistant_response_trigger() + self._triggering_assistant_response = False + + async def _disconnect(self): + try: + logger.info("Disconnecting...") + + # NOTE: see explanation of HACK, below + self._disconnecting = True + + # Clean up client + if self._client: + await self._send_session_end_events() + self._client = None + + # Clean up stream + if self._stream: + await self._stream.input_stream.close() + self._stream = None + + # NOTE: see explanation of HACK, below + await asyncio.sleep(1) + + # Clean up receive task + # HACK: we should ideally be able to cancel the receive task before stopping the input + # stream, above (meaning we wouldn't need self._disconnecting). But for some reason if + # we don't close the input stream and wait a second first, we're getting an error a lot + # like this one: https://github.com/awslabs/amazon-transcribe-streaming-sdk/issues/61. + if self._receive_task: + await self.cancel_task(self._receive_task, timeout=1.0) + self._receive_task = None + + # Reset remaining connection-specific state + self._prompt_name = None + self._input_audio_content_name = None + self._content_being_received = None + self._assistant_is_responding = False + self._ready_to_send_context = False + self._handling_bot_stopped_speaking = False + self._triggering_assistant_response = False + self._disconnecting = False + self._connected_time = None + + logger.info("Finished disconnecting") + except Exception as e: + logger.error(f"{self} error disconnecting: {e}") + + def _create_client(self) -> BedrockRuntimeClient: + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self._region}.amazonaws.com", + region=self._region, + aws_credentials_identity_resolver=StaticCredentialsResolver( + credentials=AWSCredentialsIdentity( + access_key_id=self._access_key_id, secret_access_key=self._secret_access_key + ) + ), + http_auth_scheme_resolver=HTTPAuthSchemeResolver(), + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, + ) + return BedrockRuntimeClient(config=config) + + # + # LLM communication: input events (pipecat -> LLM) + # + + async def _send_session_start_event(self): + session_start = f""" + {{ + "event": {{ + "sessionStart": {{ + "inferenceConfiguration": {{ + "maxTokens": {self._params.max_tokens}, + "topP": {self._params.top_p}, + "temperature": {self._params.temperature} + }} + }} + }} + }} + """ + await self._send_client_event(session_start) + + async def _send_prompt_start_event(self, tools: List[Any]): + if not self._prompt_name: + return + + tools_config = ( + f""", + "toolUseOutputConfiguration": {{ + "mediaType": "application/json" + }}, + "toolConfiguration": {{ + "tools": {json.dumps(tools)} + }} + """ + if tools + else "" + ) + + prompt_start = f''' + {{ + "event": {{ + "promptStart": {{ + "promptName": "{self._prompt_name}", + "textOutputConfiguration": {{ + "mediaType": "text/plain" + }}, + "audioOutputConfiguration": {{ + "mediaType": "audio/lpcm", + "sampleRateHertz": {self._params.output_sample_rate}, + "sampleSizeBits": {self._params.output_sample_size}, + "channelCount": {self._params.output_channel_count}, + "voiceId": "{self._voice_id}", + "encoding": "base64", + "audioType": "SPEECH" + }}{tools_config} + }} + }} + }} + ''' + await self._send_client_event(prompt_start) + + async def _send_audio_input_start_event(self): + if not self._prompt_name: + return + + audio_content_start = f''' + {{ + "event": {{ + "contentStart": {{ + "promptName": "{self._prompt_name}", + "contentName": "{self._input_audio_content_name}", + "type": "AUDIO", + "interactive": true, + "role": "USER", + "audioInputConfiguration": {{ + "mediaType": "audio/lpcm", + "sampleRateHertz": {self._params.input_sample_rate}, + "sampleSizeBits": {self._params.input_sample_size}, + "channelCount": {self._params.input_channel_count}, + "audioType": "SPEECH", + "encoding": "base64" + }} + }} + }} + }} + ''' + await self._send_client_event(audio_content_start) + + async def _send_text_event(self, text: str, role: Role): + if not self._stream or not self._prompt_name or not text: + return + + content_name = str(uuid.uuid4()) + + text_content_start = f''' + {{ + "event": {{ + "contentStart": {{ + "promptName": "{self._prompt_name}", + "contentName": "{content_name}", + "type": "TEXT", + "interactive": true, + "role": "{role.value}", + "textInputConfiguration": {{ + "mediaType": "text/plain" + }} + }} + }} + }} + ''' + await self._send_client_event(text_content_start) + + escaped_text = json.dumps(text) # includes quotes + text_input = f''' + {{ + "event": {{ + "textInput": {{ + "promptName": "{self._prompt_name}", + "contentName": "{content_name}", + "content": {escaped_text} + }} + }} + }} + ''' + await self._send_client_event(text_input) + + text_content_end = f''' + {{ + "event": {{ + "contentEnd": {{ + "promptName": "{self._prompt_name}", + "contentName": "{content_name}" + }} + }} + }} + ''' + await self._send_client_event(text_content_end) + + async def _send_user_audio_event(self, audio: bytes): + if not self._stream: + return + + blob = base64.b64encode(audio) + audio_event = f''' + {{ + "event": {{ + "audioInput": {{ + "promptName": "{self._prompt_name}", + "contentName": "{self._input_audio_content_name}", + "content": "{blob.decode("utf-8")}" + }} + }} + }} + ''' + await self._send_client_event(audio_event) + + async def _send_session_end_events(self): + if not self._stream or not self._prompt_name: + return + + prompt_end = f''' + {{ + "event": {{ + "promptEnd": {{ + "promptName": "{self._prompt_name}" + }} + }} + }} + ''' + await self._send_client_event(prompt_end) + + session_end = """ + { + "event": { + "sessionEnd": {} + } + } + """ + await self._send_client_event(session_end) + + async def _send_tool_result(self, tool_call_id, result): + if not self._stream or not self._prompt_name: + return + + content_name = str(uuid.uuid4()) + + result_content_start = f''' + {{ + "event": {{ + "contentStart": {{ + "promptName": "{self._prompt_name}", + "contentName": "{content_name}", + "interactive": false, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": {{ + "toolUseId": "{tool_call_id}", + "type": "TEXT", + "textInputConfiguration": {{ + "mediaType": "text/plain" + }} + }} + }} + }} + }} + ''' + await self._send_client_event(result_content_start) + + result_content = json.dumps( + { + "event": { + "toolResult": { + "promptName": self._prompt_name, + "contentName": content_name, + "content": json.dumps(result) if isinstance(result, dict) else result, + } + } + } + ) + await self._send_client_event(result_content) + + result_content_end = f""" + {{ + "event": {{ + "contentEnd": {{ + "promptName": "{self._prompt_name}", + "contentName": "{content_name}" + }} + }} + }} + """ + await self._send_client_event(result_content_end) + + async def _send_client_event(self, event_json: str): + if not self._stream: # should never happen + return + + event = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8")) + ) + await self._stream.input_stream.send(event) + + # + # LLM communication: output events (LLM -> pipecat) + # + + # Receive events for the session. + # A few different kinds of content can be delivered: + # - Transcription of user audio + # - Tool use + # - Text preview of planned response speech before audio delivered + # - User interruption notification + # - Text of response speech that whose audio was actually delivered + # - Audio of response speech + # Each piece of content is wrapped by "contentStart" and "contentEnd" events. The content is + # delivered sequentially: one piece of content will end before another starts. + # The overall completion is wrapped by "completionStart" and "completionEnd" events. + async def _receive_task_handler(self): + try: + while self._stream and not self._disconnecting: + output = await self._stream.await_output() + result = await output[1].receive() + + if result.value and result.value.bytes_: + response_data = result.value.bytes_.decode("utf-8") + json_data = json.loads(response_data) + + if "event" in json_data: + event_json = json_data["event"] + if "completionStart" in event_json: + # Handle the LLM completion starting + await self._handle_completion_start_event(event_json) + elif "contentStart" in event_json: + # Handle a piece of content starting + await self._handle_content_start_event(event_json) + elif "textOutput" in event_json: + # Handle text output content + await self._handle_text_output_event(event_json) + elif "audioOutput" in event_json: + # Handle audio output content + await self._handle_audio_output_event(event_json) + elif "toolUse" in event_json: + # Handle tool use + await self._handle_tool_use_event(event_json) + elif "contentEnd" in event_json: + # Handle a piece of content ending + await self._handle_content_end_event(event_json) + elif "completionEnd" in event_json: + # Handle the LLM completion ending + await self._handle_completion_end_event(event_json) + + except Exception as e: + logger.error(f"{self} error processing responses: {e}") + if self._wants_connection: + await self.reset_conversation() + + async def _handle_completion_start_event(self, event_json): + pass + + async def _handle_content_start_event(self, event_json): + content_start = event_json["contentStart"] + type = content_start["type"] + role = content_start["role"] + generation_stage = None + if "additionalModelFields" in content_start: + additional_model_fields = json.loads(content_start["additionalModelFields"]) + generation_stage = additional_model_fields.get("generationStage") + + # Bookkeeping: track current content being received + content = CurrentContent( + type=ContentType(type), + role=Role(role), + text_stage=TextStage(generation_stage) if generation_stage else None, + text_content=None, + ) + self._content_being_received = content + + if content.role == Role.ASSISTANT: + if content.type == ContentType.AUDIO: + # Note that an assistant response can comprise of multiple audio blocks + if not self._assistant_is_responding: + # The assistant has started responding. + self._assistant_is_responding = True + await self._report_assistant_response_started() + + async def _handle_text_output_event(self, event_json): + if not self._content_being_received: # should never happen + return + content = self._content_being_received + + text_content = event_json["textOutput"]["content"] + + # Bookkeeping: augment the current content being received with text + # Assumption: only one text content per content block + content.text_content = text_content + + async def _handle_audio_output_event(self, event_json): + if not self._content_being_received: # should never happen + return + + # Get audio + audio_content = event_json["audioOutput"]["content"] + + # Push audio frame + audio = base64.b64decode(audio_content) + frame = TTSAudioRawFrame( + audio=audio, + sample_rate=self._params.output_sample_rate, + num_channels=self._params.output_channel_count, + ) + await self.push_frame(frame) + + async def _handle_tool_use_event(self, event_json): + if not self._content_being_received or not self._context: # should never happen + return + + # Get tool use details + tool_use = event_json["toolUse"] + function_name = tool_use["toolName"] + tool_call_id = tool_use["toolUseId"] + arguments = json.loads(tool_use["content"]) + + # Call tool function + if self.has_function(function_name): + if function_name in self._functions.keys() or None in self._functions.keys(): + await self.call_function( + context=self._context, + tool_call_id=tool_call_id, + function_name=function_name, + arguments=arguments, + ) + else: + raise AWSNovaSonicUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + ) + + async def _handle_content_end_event(self, event_json): + if not self._content_being_received: # should never happen + return + content = self._content_being_received + + content_end = event_json["contentEnd"] + stop_reason = content_end["stopReason"] + + # Bookkeeping: clear current content being received + self._content_being_received = None + + if content.role == Role.ASSISTANT: + if content.type == ContentType.TEXT: + # Ignore non-final text, and the "interrupted" message (which isn't meaningful text) + if content.text_stage == TextStage.FINAL and stop_reason != "INTERRUPTED": + if self._assistant_is_responding: + # Text added to the ongoing assistant response + await self._report_assistant_response_text_added(content.text_content) + elif content.role == Role.USER: + if content.type == ContentType.TEXT: + if content.text_stage == TextStage.FINAL: + # User transcription text added + await self._report_user_transcription_text_added(content.text_content) + + async def _handle_completion_end_event(self, event_json): + pass + + async def _report_assistant_response_started(self): + logger.debug("Assistant response started") + + # Report that the assistant has started their response. + await self.push_frame(LLMFullResponseStartFrame()) + + # Report that equivalent of TTS (this is a speech-to-speech model) started + await self.push_frame(TTSStartedFrame()) + + async def _report_assistant_response_text_added(self, text): + if not self._context: # should never happen + return + + logger.debug(f"Assistant response text added: {text}") + + # Report some text added to the ongoing assistant response + await self.push_frame(LLMTextFrame(text)) + + # Report some text added to the *equivalent* of TTS (this is a speech-to-speech model) + await self.push_frame(TTSTextFrame(text)) + + # TODO: this is a (hopefully temporary) HACK. Here we directly manipulate the context rather + # than relying on the frames pushed to the assistant context aggregator. The pattern of + # receiving full-sentence text after the assistant has spoken does not easily fit with the + # Pipecat expectation of chunks of text streaming in while the assistant is speaking. + # Interruption handling was especially challenging. Rather than spend days trying to fit a + # square peg in a round hole, I decided on this hack for the time being. We can most cleanly + # abandon this hack if/when AWS Nova Sonic implements streaming smaller text chunks + # interspersed with audio. Note that when we move away from this hack, we need to make sure + # that on an interruption we avoid sending LLMFullResponseEndFrame, which gets the + # LLMAssistantContextAggregator into a bad state. + self._context.buffer_assistant_text(text) + + async def _report_assistant_response_ended(self): + if not self._context: # should never happen + return + + logger.debug("Assistant response ended") + + # Report that the assistant has finished their response. + await self.push_frame(LLMFullResponseEndFrame()) + + # Report that equivalent of TTS (this is a speech-to-speech model) stopped. + await self.push_frame(TTSStoppedFrame()) + + # For an explanation of this hack, see _report_assistant_response_text_added. + self._context.flush_aggregated_assistant_text() + + async def _report_user_transcription_text_added(self, text): + if not self._context: # should never happen + return + + logger.debug(f"User transcription text added: {text}") + + # Manually add new user transcription text to context. + # We can't rely on the user context aggregator to do this since it's upstream from the LLM. + self._context.add_user_transcription_text(text) + + # Report that some new user transcription text is available. + if self._send_transcription_frames: + await self.push_frame( + TranscriptionFrame(text=text, user_id="", timestamp=time_now_iso8601()) + ) + + # + # context + # + + def create_context_aggregator( + self, + context: OpenAILLMContext, + *, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), + ) -> AWSNovaSonicContextAggregatorPair: + context.set_llm_adapter(self.get_llm_adapter()) + + user = AWSNovaSonicUserContextAggregator(context=context, params=user_params) + assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params) + + return AWSNovaSonicContextAggregatorPair(user, assistant) + + # + # assistant response trigger (HACK) + # + + # Class variable + AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION = ( + "Start speaking when you hear the user say 'ready', but don't consider that 'ready' to be " + "a meaningful part of the conversation other than as a trigger for you to start speaking." + ) + + async def trigger_assistant_response(self): + if self._triggering_assistant_response: + return False + + self._triggering_assistant_response = True + + # Read audio bytes, if we don't already have them cached + if not self._assistant_response_trigger_audio: + file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav") + with wave.open(file_path.open("rb"), "rb") as wav_file: + self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes()) + + # Send the trigger audio, if we're fully connected and set up + if self._connected_time is not None: + await self._send_assistant_response_trigger() + self._triggering_assistant_response = False + + async def _send_assistant_response_trigger(self): + if ( + not self._assistant_response_trigger_audio or self._connected_time is None + ): # should never happen + return + + logger.debug("Sending assistant response trigger...") + + chunk_duration = 0.02 # what we might get from InputAudioRawFrame + chunk_size = int( + chunk_duration + * self._params.input_sample_rate + * self._params.input_channel_count + * (self._params.input_sample_size / 8) + ) # e.g. 0.02 seconds of 16-bit (2-byte) PCM mono audio at 16kHz is 640 bytes + + # Lead with a bit of blank audio, if needed. + # It seems like the LLM can't quite "hear" the first little bit of audio sent on a + # connection. + current_time = time.time() + max_blank_audio_duration = 0.5 + blank_audio_duration = ( + max_blank_audio_duration - (current_time - self._connected_time) + if self._connected_time is not None + and (current_time - self._connected_time) < max_blank_audio_duration + else None + ) + if blank_audio_duration: + logger.debug( + f"Leading assistant response trigger with {blank_audio_duration}s of blank audio" + ) + blank_audio_chunk = b"\x00" * chunk_size + num_chunks = int(blank_audio_duration / chunk_duration) + for _ in range(num_chunks): + await self._send_user_audio_event(blank_audio_chunk) + await asyncio.sleep(chunk_duration) + + # Send trigger audio + # NOTE: this audio *will* be transcribed and eventually make it into the context. That's OK: + # if we ever need to seed this service again with context it would make sense to include it + # since the instruction (i.e. the "wait for the trigger" instruction) will be part of the + # context as well. + audio_chunks = [ + self._assistant_response_trigger_audio[i : i + chunk_size] + for i in range(0, len(self._assistant_response_trigger_audio), chunk_size) + ] + for chunk in audio_chunks: + await self._send_user_audio_event(chunk) + await asyncio.sleep(chunk_duration) diff --git a/src/pipecat/services/aws_nova_sonic/context.py b/src/pipecat/services/aws_nova_sonic/context.py new file mode 100644 index 000000000..561ae53db --- /dev/null +++ b/src/pipecat/services/aws_nova_sonic/context.py @@ -0,0 +1,217 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import copy +from dataclasses import dataclass, field +from enum import Enum + +from loguru import logger + +from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + DataFrame, + Frame, + FunctionCallResultFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesAppendFrame, + LLMMessagesUpdateFrame, + LLMSetToolChoiceFrame, + LLMSetToolsFrame, + StartInterruptionFrame, + TextFrame, + UserImageRawFrame, +) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame +from pipecat.services.openai.llm import ( + OpenAIAssistantContextAggregator, + OpenAIUserContextAggregator, +) + + +class Role(Enum): + SYSTEM = "SYSTEM" + USER = "USER" + ASSISTANT = "ASSISTANT" + TOOL = "TOOL" + + +@dataclass +class AWSNovaSonicConversationHistoryMessage: + role: Role # only USER and ASSISTANT + text: str + + +@dataclass +class AWSNovaSonicConversationHistory: + system_instruction: str = None + messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list) + + +class AWSNovaSonicLLMContext(OpenAILLMContext): + def __init__(self, messages=None, tools=None, **kwargs): + super().__init__(messages=messages, tools=tools, **kwargs) + self.__setup_local() + + def __setup_local(self, system_instruction: str = ""): + self._assistant_text = "" + self._system_instruction = system_instruction + + @staticmethod + def upgrade_to_nova_sonic( + obj: OpenAILLMContext, system_instruction: str + ) -> "AWSNovaSonicLLMContext": + if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext): + obj.__class__ = AWSNovaSonicLLMContext + obj.__setup_local(system_instruction) + return obj + + # NOTE: this method has the side-effect of updating _system_instruction from messages + def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory: + history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction) + + # Bail if there are no messages + if not self.messages: + return history + + messages = copy.deepcopy(self.messages) + + # If we have a "system" message as our first message, let's pull that out into "instruction" + if messages[0].get("role") == "system": + system = messages.pop(0) + content = system.get("content") + if isinstance(content, str): + history.system_instruction = content + elif isinstance(content, list): + history.system_instruction = content[0].get("text") + if history.system_instruction: + self._system_instruction = history.system_instruction + + # Process remaining messages to fill out conversation history. + # Nova Sonic supports "user" and "assistant" messages in history. + for message in messages: + history_message = self.from_standard_message(message) + if history_message: + history.messages.append(history_message) + + return history + + def get_messages_for_persistent_storage(self): + messages = super().get_messages_for_persistent_storage() + # If we have a system instruction and messages doesn't already contain it, add it + if self._system_instruction and not (messages and messages[0].get("role") == "system"): + messages.insert(0, {"role": "system", "content": self._system_instruction}) + return messages + + def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage: + role = message.get("role") + if message.get("role") == "user" or message.get("role") == "assistant": + content = message.get("content") + if isinstance(message.get("content"), list): + content = "" + for c in message.get("content"): + if c.get("type") == "text": + content += " " + c.get("text") + else: + logger.error( + f"Unhandled content type in context message: {c.get('type')} - {message}" + ) + # There won't be content if this is an assistant tool call entry. + # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation + # history + if content: + return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content) + # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova + # Sonic conversation history + + def add_user_transcription_text(self, text): + message = { + "role": "user", + "content": [{"type": "text", "text": text}], + } + self.add_message(message) + # logger.debug(f"Context updated (user): {self.get_messages_for_logging()}") + + def buffer_assistant_text(self, text): + self._assistant_text += text + # logger.debug(f"Assistant text buffered: {self._assistant_text}") + + def flush_aggregated_assistant_text(self): + if not self._assistant_text: + return + message = { + "role": "assistant", + "content": [{"type": "text", "text": self._assistant_text}], + } + self._assistant_text = "" + self.add_message(message) + # logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}") + + +@dataclass +class AWSNovaSonicMessagesUpdateFrame(DataFrame): + context: AWSNovaSonicLLMContext + + +class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator): + async def process_frame( + self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM + ): + await super().process_frame(frame, direction) + + # Parent does not push LLMMessagesUpdateFrame + if isinstance(frame, LLMMessagesUpdateFrame): + await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context)) + + +class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator): + async def process_frame(self, frame: Frame, direction: FrameDirection): + # HACK: For now, disable the context aggregator by making it just pass through all frames + # that the parent handles (except the function call stuff, which we still need). + # For an explanation of this hack, see + # AWSNovaSonicLLMService._report_assistant_response_text_added. + if isinstance( + frame, + ( + StartInterruptionFrame, + LLMFullResponseStartFrame, + LLMFullResponseEndFrame, + TextFrame, + LLMMessagesAppendFrame, + LLMMessagesUpdateFrame, + LLMSetToolsFrame, + LLMSetToolChoiceFrame, + UserImageRawFrame, + BotStoppedSpeakingFrame, + ), + ): + await self.push_frame(frame, direction) + else: + await super().process_frame(frame, direction) + + async def handle_function_call_result(self, frame: FunctionCallResultFrame): + await super().handle_function_call_result(frame) + + # The standard function callback code path pushes the FunctionCallResultFrame from the LLM + # itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side + # context. Let's push a special frame to do that. + await self.push_frame( + AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM + ) + + +@dataclass +class AWSNovaSonicContextAggregatorPair: + _user: AWSNovaSonicUserContextAggregator + _assistant: AWSNovaSonicAssistantContextAggregator + + def user(self) -> AWSNovaSonicUserContextAggregator: + return self._user + + def assistant(self) -> AWSNovaSonicAssistantContextAggregator: + return self._assistant diff --git a/src/pipecat/services/aws_nova_sonic/frames.py b/src/pipecat/services/aws_nova_sonic/frames.py new file mode 100644 index 000000000..94d410f22 --- /dev/null +++ b/src/pipecat/services/aws_nova_sonic/frames.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dataclasses import dataclass + +from pipecat.frames.frames import DataFrame, FunctionCallResultFrame + + +@dataclass +class AWSNovaSonicFunctionCallResultFrame(DataFrame): + result_frame: FunctionCallResultFrame diff --git a/src/pipecat/services/aws_nova_sonic/ready.wav b/src/pipecat/services/aws_nova_sonic/ready.wav new file mode 100644 index 000000000..ca932afa6 Binary files /dev/null and b/src/pipecat/services/aws_nova_sonic/ready.wav differ diff --git a/src/pipecat/services/canonical/__init__.py b/src/pipecat/services/canonical/__init__.py deleted file mode 100644 index f47b99c4e..000000000 --- a/src/pipecat/services/canonical/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import sys - -from pipecat.services import DeprecatedModuleProxy - -from .metrics import * - -sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics") diff --git a/src/pipecat/services/canonical/metrics.py b/src/pipecat/services/canonical/metrics.py deleted file mode 100644 index 012cd4ab7..000000000 --- a/src/pipecat/services/canonical/metrics.py +++ /dev/null @@ -1,230 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import io -import os -import uuid -import wave -from datetime import datetime -from typing import Dict, List, Optional, Tuple - -import aiohttp -from loguru import logger - -from pipecat.frames.frames import CancelFrame, EndFrame, Frame -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_service import AIService - -try: - import aiofiles - import aiofiles.os -except ModuleNotFoundError as e: - logger.error(f"Exception: {e}") - logger.error( - "In order to use Canonical Metrics, you need to `pip install pipecat-ai[canonical]`. " - + "Also, set the `CANONICAL_API_KEY` environment variable." - ) - raise Exception(f"Missing module: {e}") - - -# Multipart upload part size in bytes, cannot be smaller than 5MB -PART_SIZE = 1024 * 1024 * 5 - - -class CanonicalMetricsService(AIService): - """Initialize a CanonicalAudioProcessor instance. - - This class uses an AudioBufferProcessor to get the conversation audio and - uploads it to Canonical Voice API for audio processing. - - Args: - call_id (str): Your unique identifier for the call. This is used to match the call in the Canonical Voice system to the call in your system. - assistant (str): Identifier for the AI assistant. This can be whatever you want, it's intended for you convenience so you can distinguish - between different assistants and a grouping mechanism for calls. - assistant_speaks_first (bool, optional): Indicates if the assistant speaks first in the conversation. Defaults to True. - output_dir (str, optional): Directory to save temporary audio files. Defaults to "recordings". - - Attributes: - call_id (str): Stores the unique call identifier. - assistant (str): Stores the assistant identifier. - assistant_speaks_first (bool): Indicates whether the assistant speaks first. - output_dir (str): Directory path for saving temporary audio files. - - The constructor also ensures that the output directory exists. - """ - - def __init__( - self, - *, - aiohttp_session: aiohttp.ClientSession, - call_id: str, - assistant: str, - api_key: str, - api_url: str = "https://voiceapp.canonical.chat/api/v1", - assistant_speaks_first: bool = True, - output_dir: str = "recordings", - audio_buffer_processor: Optional[AudioBufferProcessor] = None, - context: Optional[OpenAILLMContext] = None, - **kwargs, - ): - super().__init__(**kwargs) - # Validate that at least one of audio_buffer_processor or context is provided - if audio_buffer_processor is None and context is None: - raise ValueError("At least one of audio_buffer_processor or context must be specified") - - self._aiohttp_session = aiohttp_session - self._audio_buffer_processor = audio_buffer_processor - self._api_key = api_key - self._api_url = api_url - self._call_id = call_id - self._assistant = assistant - self._assistant_speaks_first = assistant_speaks_first - self._output_dir = output_dir - self._context = context - - async def stop(self, frame: EndFrame): - await super().stop(frame) - await self._process_completion() - - async def cancel(self, frame: CancelFrame): - await super().cancel(frame) - await self._process_completion() - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - await self.push_frame(frame, direction) - - async def _process_completion(self): - if self._audio_buffer_processor is not None: - await self._process_audio() - elif self._context is not None: - await self._process_transcript() - - async def _process_transcript(self): - params = { - "callId": self._call_id, - "assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first}, - "transcript": self._context.messages, - } - response = await self._aiohttp_session.post( - f"{self._api_url}/call", - headers=self._request_headers(), - json=params, - ) - if not response.ok: - logger.error(f"Failed to process transcript: {await response.text()}") - - async def _process_audio(self): - audio_buffer_processor = self._audio_buffer_processor - - if not audio_buffer_processor.has_audio(): - return - - os.makedirs(self._output_dir, exist_ok=True) - filename = self._get_output_filename() - audio = audio_buffer_processor.merge_audio_buffers() - - with io.BytesIO() as buffer: - with wave.open(buffer, "wb") as wf: - wf.setsampwidth(2) - wf.setnchannels(audio_buffer_processor.num_channels) - wf.setframerate(audio_buffer_processor.sample_rate) - wf.writeframes(audio) - async with aiofiles.open(filename, "wb") as file: - await file.write(buffer.getvalue()) - - try: - await self._multipart_upload(filename) - await aiofiles.os.remove(filename) - except FileNotFoundError: - pass - except Exception as e: - logger.error(f"Failed to upload recording: {e}") - - def _get_output_filename(self): - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - return f"{self._output_dir}/{timestamp}-{uuid.uuid4().hex}.wav" - - def _request_headers(self): - return {"Content-Type": "application/json", "X-Canonical-Api-Key": self._api_key} - - async def _multipart_upload(self, file_path: str): - upload_request, upload_response = await self._request_upload(file_path) - if upload_request is None or upload_response is None: - return - parts = await self._upload_parts(file_path, upload_response) - if parts is None: - return - await self._upload_complete(parts, upload_request, upload_response) - - async def _request_upload(self, file_path: str) -> Tuple[Dict, Dict]: - filename = os.path.basename(file_path) - filesize = os.path.getsize(file_path) - numparts = int((filesize + PART_SIZE - 1) / PART_SIZE) - - params = { - "filename": filename, - "parts": numparts, - "callId": self._call_id, - "assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first}, - } - logger.debug(f"Requesting presigned URLs for {numparts} parts") - response = await self._aiohttp_session.post( - f"{self._api_url}/recording/uploadRequest", headers=self._request_headers(), json=params - ) - if not response.ok: - logger.error(f"Failed to get presigned URLs: {await response.text()}") - return None, None - response_json = await response.json() - return params, response_json - - async def _upload_parts(self, file_path: str, upload_response: Dict) -> List[Dict]: - urls = upload_response["urls"] - parts = [] - try: - async with aiofiles.open(file_path, "rb") as file: - for partnum, upload_url in enumerate(urls, start=1): - data = await file.read(PART_SIZE) - if not data: - break - - response = await self._aiohttp_session.put(upload_url, data=data) - if not response.ok: - logger.error(f"Failed to upload part {partnum}: {await response.text()}") - return None - - etag = response.headers["ETag"] - parts.append({"partnum": str(partnum), "etag": etag}) - - except Exception as e: - logger.error(f"Multipart upload aborted, an error occurred: {str(e)}") - return parts - - async def _upload_complete( - self, parts: List[Dict], upload_request: Dict, upload_response: Dict - ): - params = { - "filename": upload_request["filename"], - "parts": parts, - "slug": upload_response["slug"], - "callId": self._call_id, - "assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first}, - } - if self._context is not None: - params["transcript"] = self._context.messages - - logger.debug(f"Completing upload for {params['filename']}") - logger.debug(f"Slug: {params['slug']}") - response = await self._aiohttp_session.post( - f"{self._api_url}/recording/uploadComplete", - headers=self._request_headers(), - json=params, - ) - if not response.ok: - logger.error(f"Failed to complete upload: {await response.text()}") - return diff --git a/src/pipecat/services/deepgram/tts.py b/src/pipecat/services/deepgram/tts.py index ec8a755a0..93c710f9e 100644 --- a/src/pipecat/services/deepgram/tts.py +++ b/src/pipecat/services/deepgram/tts.py @@ -30,7 +30,7 @@ class DeepgramTTSService(TTSService): self, *, api_key: str, - voice: str = "aura-helios-en", + voice: str = "aura-2-helena-en", base_url: str = "", sample_rate: Optional[int] = None, encoding: str = "linear16", diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index 4362fcdc9..324e8099e 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -7,11 +7,12 @@ import asyncio import base64 import json +import uuid from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union import aiohttp from loguru import logger -from pydantic import BaseModel, model_validator +from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, @@ -26,7 +27,10 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService +from pipecat.services.tts_service import ( + AudioContextWordTTSService, + WordTTSService, +) from pipecat.transcriptions.language import Language # See .env.example for ElevenLabs configuration needed @@ -159,26 +163,17 @@ def calculate_word_times( return word_times -class ElevenLabsTTSService(InterruptibleWordTTSService): +class ElevenLabsTTSService(AudioContextWordTTSService): class InputParams(BaseModel): language: Optional[Language] = None - optimize_streaming_latency: Optional[str] = None stability: Optional[float] = None similarity_boost: Optional[float] = None style: Optional[float] = None use_speaker_boost: Optional[bool] = None speed: Optional[float] = None auto_mode: Optional[bool] = True - - @model_validator(mode="after") - def validate_voice_settings(self): - stability = self.stability - similarity_boost = self.similarity_boost - if (stability is None) != (similarity_boost is None): - raise ValueError( - "Both 'stability' and 'similarity_boost' must be provided when using voice settings" - ) - return self + enable_ssml_parsing: Optional[bool] = None + enable_logging: Optional[bool] = None def __init__( self, @@ -220,13 +215,14 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): "language": self.language_to_service_language(params.language) if params.language else None, - "optimize_streaming_latency": params.optimize_streaming_latency, "stability": params.stability, "similarity_boost": params.similarity_boost, "style": params.style, "use_speaker_boost": params.use_speaker_boost, "speed": params.speed, "auto_mode": str(params.auto_mode).lower(), + "enable_ssml_parsing": params.enable_ssml_parsing, + "enable_logging": params.enable_logging, } self.set_model_name(model) self.set_voice(voice_id) @@ -238,6 +234,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): self._started = False self._cumulative_time = 0 + # Context management for v1 multi API + self._context_id = None self._receive_task = None self._keepalive_task = None @@ -253,15 +251,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): async def set_model(self, model: str): await super().set_model(model) logger.info(f"Switching TTS model to: [{model}]") - await self._disconnect() - await self._connect() + # No need to disconnect/reconnect for model changes with multi-context API async def _update_settings(self, settings: Mapping[str, Any]): prev_voice = self._voice_id await super()._update_settings(settings) + # If voice changes, we don't need to reconnect, just use a new context if not prev_voice == self._voice_id: - await self._disconnect() - await self._connect() logger.info(f"Switching TTS voice to: [{self._voice_id}]") async def start(self, frame: StartFrame): @@ -278,8 +274,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): await self._disconnect() async def flush_audio(self): - if self._websocket: - msg = {"text": " ", "flush": True} + if self._websocket and self._context_id: + msg = {"context_id": self._context_id, "flush": True} await self._websocket.send(json.dumps(msg)) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -319,10 +315,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): voice_id = self._voice_id model = self.model_name output_format = self._output_format - url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}" + url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}" - if self._settings["optimize_streaming_latency"]: - url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}" + if self._settings["enable_ssml_parsing"]: + url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}" + + if self._settings["enable_logging"]: + url += f"&enable_logging={self._settings['enable_logging']}" # Language can only be used with the ELEVENLABS_MULTILINGUAL_MODELS language = self._settings["language"] @@ -337,14 +336,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): # Set max websocket message size to 16MB for large audio responses self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024) - # According to ElevenLabs, we should always start with a single space. - msg: Dict[str, Any] = { - "text": " ", - "xi_api_key": self._api_key, - } - if self._voice_settings: - msg["voice_settings"] = self._voice_settings - await self._websocket.send(json.dumps(msg)) except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -356,12 +347,15 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): if self._websocket: logger.debug("Disconnecting from ElevenLabs") - await self._websocket.send(json.dumps({"text": ""})) + # Close all contexts and the socket + if self._context_id: + await self._websocket.send(json.dumps({"close_socket": True})) await self._websocket.close() except Exception as e: logger.error(f"{self} error closing websocket: {e}") finally: self._started = False + self._context_id = None self._websocket = None def _get_websocket(self): @@ -369,9 +363,35 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): return self._websocket raise Exception("Websocket not connected") + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + await super()._handle_interruption(frame, direction) + + # Close the current context when interrupted without closing the websocket + if self._context_id and self._websocket: + logger.trace(f"Closing context {self._context_id} due to interruption") + try: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True}) + ) + except Exception as e: + logger.error(f"Error closing context on interruption: {e}") + self._context_id = None + self._started = False + async def _receive_messages(self): async for message in self._get_websocket(): msg = json.loads(message) + # Check if this message belongs to the current context + # The default context may return null/None for context_id + received_ctx_id = msg.get("context_id") + if ( + self._context_id is not None + and received_ctx_id is not None + and received_ctx_id != self._context_id + ): + logger.trace(f"Ignoring message from different context: {received_ctx_id}") + continue + if msg.get("audio"): await self.stop_ttfb_metrics() self.start_word_timestamps() @@ -383,20 +403,45 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): word_times = calculate_word_times(msg["alignment"], self._cumulative_time) await self.add_word_timestamps(word_times) self._cumulative_time = word_times[-1][1] + if msg.get("is_final"): + logger.trace(f"Received final message for context {received_ctx_id}") + # Context has finished + if self._context_id == received_ctx_id: + self._context_id = None + self._started = False async def _keepalive_task_handler(self): while True: await asyncio.sleep(10) try: - await self._send_text("") + # Send an empty message to keep the connection alive + if self._websocket and self._websocket.open: + await self._websocket.send(json.dumps({})) except websockets.ConnectionClosed as e: logger.warning(f"{self} keepalive error: {e}") break async def _send_text(self, text: str): if self._websocket: - msg = {"text": text + " "} - await self._websocket.send(json.dumps(msg)) + if not self._context_id: + # First message for a new context - need a space to initialize + msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key} + + # Add voice settings only in first message for a context + if self._voice_settings: + msg["voice_settings"] = self._voice_settings + + await self._websocket.send(json.dumps(msg)) + self._context_id = msg["context_id"] + logger.trace(f"Created new context {self._context_id}") + + # Now send the actual text content + msg = {"text": text, "context_id": self._context_id} + await self._websocket.send(json.dumps(msg)) + else: + # Continuing with an existing context + msg = {"text": text, "context_id": self._context_id} + await self._websocket.send(json.dumps(msg)) async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"{self}: Generating TTS [{text}]") @@ -406,6 +451,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): await self._connect() try: + # Close previous context if there was one + if self._context_id and not self._started: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True}) + ) + self._context_id = None + if not self._started: await self.start_ttfb_metrics() yield TTSStartedFrame() @@ -417,8 +469,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): except Exception as e: logger.error(f"{self} error sending message: {e}") yield TTSStoppedFrame() - await self._disconnect() - await self._connect() + self._started = False + self._context_id = None return yield None except Exception as e: diff --git a/src/pipecat/services/google/rtvi.py b/src/pipecat/services/google/rtvi.py index 88e67e6c6..cd60f6f1f 100644 --- a/src/pipecat/services/google/rtvi.py +++ b/src/pipecat/services/google/rtvi.py @@ -9,8 +9,9 @@ from typing import List, Literal, Optional from pydantic import BaseModel from pipecat.frames.frames import Frame +from pipecat.observers.base_observer import FramePushed from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.processors.frameworks.rtvi import RTVIObserver +from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame @@ -27,18 +28,13 @@ class RTVIBotLLMSearchResponseMessage(BaseModel): class GoogleRTVIObserver(RTVIObserver): - def __init__(self, rtvi: FrameProcessor): + def __init__(self, rtvi: RTVIProcessor): super().__init__(rtvi) - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): - await super().on_push_frame(src, dst, frame, direction, timestamp) + async def on_push_frame(self, data: FramePushed): + await super().on_push_frame(data) + + frame = data.frame if isinstance(frame, LLMSearchResponseFrame): await self._handle_llm_search_response_frame(frame) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 15b2bd6e5..21b62325d 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -190,6 +190,7 @@ class LLMService(AIService): function_name: Optional[str] = None, tool_call_id: Optional[str] = None, text_content: Optional[str] = None, + video_source: Optional[str] = None, ): await self.push_frame( UserImageRequestFrame( @@ -197,6 +198,7 @@ class LLMService(AIService): function_name=function_name, tool_call_id=tool_call_id, context=text_content, + video_source=video_source, ), FrameDirection.UPSTREAM, ) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index 334ce98c8..0c37f73ce 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -577,15 +577,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): arguments = json.loads(item.arguments) if self.has_function(function_name): run_llm = index == total_items - 1 - if function_name in self._functions.keys(): - await self.call_function( - context=self._context, - tool_call_id=tool_id, - function_name=function_name, - arguments=arguments, - run_llm=run_llm, - ) - elif None in self._functions.keys(): + if function_name in self._functions.keys() or None in self._functions.keys(): await self.call_function( context=self._context, tool_call_id=tool_id, diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index e2368ba09..b5dfc5de1 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -15,7 +15,7 @@ from pipecat.frames.frames import ( StartFrame, SystemFrame, ) -from pipecat.observers.base_observer import BaseObserver +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask @@ -42,14 +42,10 @@ class HeartbeatsObserver(BaseObserver): self._target = target self._callback = heartbeat_callback - async def on_push_frame( - self, - src: FrameProcessor, - dst: FrameProcessor, - frame: Frame, - direction: FrameDirection, - timestamp: int, - ): + async def on_push_frame(self, data: FramePushed): + src = data.source + frame = data.frame + if src == self._target and isinstance(frame, HeartbeatFrame): await self._callback(self._target, frame) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 51ebdb677..f9a27a6d3 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -122,6 +122,7 @@ class BaseInputTransport(FrameProcessor): # Configure VAD analyzer. if self._params.vad_analyzer: self._params.vad_analyzer.set_sample_rate(self._sample_rate) + # Configure End of turn analyzer. if self._params.turn_analyzer: self._params.turn_analyzer.set_sample_rate(self._sample_rate) @@ -129,10 +130,6 @@ class BaseInputTransport(FrameProcessor): # Start audio filter. if self._params.audio_in_filter: await self._params.audio_in_filter.start(self._sample_rate) - # Create audio input queue and task if needed. - if not self._audio_task and self._params.audio_in_enabled: - self._audio_in_queue = asyncio.Queue() - self._audio_task = self.create_task(self._audio_task_handler()) async def stop(self, frame: EndFrame): # Cancel and wait for the audio input task to finish. @@ -149,6 +146,13 @@ class BaseInputTransport(FrameProcessor): await self.cancel_task(self._audio_task) self._audio_task = None + async def set_transport_ready(self, frame: StartFrame): + """To be called when the transport is ready to stream.""" + # Create audio input queue and task if needed. + if not self._audio_task and self._params.audio_in_enabled: + self._audio_in_queue = asyncio.Queue() + self._audio_task = self.create_task(self._audio_task_handler()) + async def push_audio_frame(self, frame: InputAudioRawFrame): if self._params.audio_in_enabled: await self._audio_in_queue.put(frame) diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index fa5d5e1c4..81492b84d 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -78,6 +78,16 @@ class BaseOutputTransport(FrameProcessor): audio_bytes_10ms = int(self._sample_rate / 100) * self._params.audio_out_channels * 2 self._audio_chunk_size = audio_bytes_10ms * self._params.audio_out_10ms_chunks + async def stop(self, frame: EndFrame): + for _, sender in self._media_senders.items(): + await sender.stop(frame) + + async def cancel(self, frame: CancelFrame): + for _, sender in self._media_senders.items(): + await sender.cancel(frame) + + async def set_transport_ready(self, frame: StartFrame): + """To be called when the transport is ready to stream.""" # Register destinations. for destination in self._params.audio_out_destinations: await self.register_audio_destination(destination) @@ -112,14 +122,6 @@ class BaseOutputTransport(FrameProcessor): ) await self._media_senders[destination].start(frame) - async def stop(self, frame: EndFrame): - for _, sender in self._media_senders.items(): - await sender.stop(frame) - - async def cancel(self, frame: CancelFrame): - for _, sender in self._media_senders.items(): - await sender.cancel(frame) - async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): pass diff --git a/src/pipecat/transports/local/audio.py b/src/pipecat/transports/local/audio.py index 8bfd7ee34..ba554c9e3 100644 --- a/src/pipecat/transports/local/audio.py +++ b/src/pipecat/transports/local/audio.py @@ -61,6 +61,8 @@ class LocalAudioInputTransport(BaseInputTransport): ) self._in_stream.start_stream() + await self.set_transport_ready(frame) + async def cleanup(self): await super().cleanup() if self._in_stream: @@ -111,6 +113,8 @@ class LocalAudioOutputTransport(BaseOutputTransport): ) self._out_stream.start_stream() + await self.set_transport_ready(frame) + async def cleanup(self): await super().cleanup() if self._out_stream: diff --git a/src/pipecat/transports/local/tk.py b/src/pipecat/transports/local/tk.py index bed6371c2..4086497cb 100644 --- a/src/pipecat/transports/local/tk.py +++ b/src/pipecat/transports/local/tk.py @@ -68,6 +68,8 @@ class TkInputTransport(BaseInputTransport): ) self._in_stream.start_stream() + await self.set_transport_ready(frame) + async def cleanup(self): await super().cleanup() if self._in_stream: @@ -124,6 +126,8 @@ class TkOutputTransport(BaseOutputTransport): ) self._out_stream.start_stream() + await self.set_transport_ready(frame) + async def cleanup(self): await super().cleanup() if self._out_stream: diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index 4a20bc49b..f04d56b0d 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -131,6 +131,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): await self._client.trigger_client_connected() if not self._receive_task: self._receive_task = self.create_task(self._receive_messages()) + await self.set_transport_ready(frame) async def _stop_tasks(self): if self._monitor_websocket_task: @@ -204,6 +205,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self._client.setup(frame) await self._params.serializer.setup(frame) self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2 + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) diff --git a/src/pipecat/transports/network/small_webrtc.py b/src/pipecat/transports/network/small_webrtc.py index fdd501299..ffa3f441a 100644 --- a/src/pipecat/transports/network/small_webrtc.py +++ b/src/pipecat/transports/network/small_webrtc.py @@ -395,6 +395,7 @@ class SmallWebRTCInputTransport(BaseInputTransport): self._receive_audio_task = self.create_task(self._receive_audio()) if not self._receive_video_task and self._params.video_in_enabled: self._receive_video_task = self.create_task(self._receive_video()) + await self.set_transport_ready(frame) async def _stop_tasks(self): if self._receive_audio_task: @@ -487,6 +488,7 @@ class SmallWebRTCOutputTransport(BaseOutputTransport): await super().start(frame) await self._client.setup(self._params, frame) await self._client.connect() + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) diff --git a/src/pipecat/transports/network/websocket_client.py b/src/pipecat/transports/network/websocket_client.py index 7e9725a76..535a0ab21 100644 --- a/src/pipecat/transports/network/websocket_client.py +++ b/src/pipecat/transports/network/websocket_client.py @@ -136,6 +136,7 @@ class WebsocketClientInputTransport(BaseInputTransport): await self._params.serializer.setup(frame) await self._session.setup(frame) await self._session.connect() + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) @@ -186,6 +187,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport): await self._params.serializer.setup(frame) await self._session.setup(frame) await self._session.connect() + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index b930f9fd6..7c8738871 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -83,6 +83,7 @@ class WebsocketServerInputTransport(BaseInputTransport): await self._params.serializer.setup(frame) if not self._server_task: self._server_task = self.create_task(self._server_task_handler()) + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) @@ -195,6 +196,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport): await super().start(frame) await self._params.serializer.setup(frame) self._send_interval = (self.audio_chunk_size / self.sample_rate) / 2 + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): await super().stop(frame) diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 3e43ddee1..f1a514d0e 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -175,6 +175,7 @@ class DailyCallbacks(BaseModel): """Callback handlers for Daily events. Attributes: + on_active_speaker_changed: Called when the active speaker of the call has changed. on_joined: Called when bot successfully joined a room. on_left: Called when bot left a room. on_error: Called when an error occurs. @@ -201,6 +202,7 @@ class DailyCallbacks(BaseModel): on_recording_error: Called when recording encounters an error. """ + on_active_speaker_changed: Callable[[Mapping[str, Any]], Awaitable[None]] on_joined: Callable[[Mapping[str, Any]], Awaitable[None]] on_left: Callable[[], Awaitable[None]] on_error: Callable[[str], Awaitable[None]] @@ -698,7 +700,7 @@ class DailyTransportClient(EventHandler): await self.update_subscriptions(participant_settings={participant_id: media}) - self._audio_renderers[participant_id] = {audio_source: callback} + self._audio_renderers.setdefault(participant_id, {})[audio_source] = callback self._client.set_audio_renderer( participant_id, @@ -722,7 +724,7 @@ class DailyTransportClient(EventHandler): await self.update_subscriptions(participant_settings={participant_id: media}) - self._video_renderers[participant_id] = {video_source: callback} + self._video_renderers.setdefault(participant_id, {})[video_source] = callback self._client.set_video_renderer( participant_id, @@ -789,6 +791,9 @@ class DailyTransportClient(EventHandler): # Daily (EventHandler) # + def on_active_speaker_changed(self, participant): + self._call_async_callback(self._callbacks.on_active_speaker_changed, participant) + def on_app_message(self, message: Any, sender: str): self._call_async_callback(self._callbacks.on_app_message, message, sender) @@ -944,19 +949,23 @@ class DailyInputTransport(BaseInputTransport): self._audio_in_task = self.create_task(self._audio_in_task_handler()) async def start(self, frame: StartFrame): - # Setup client. - await self._client.setup(frame) - - # Parent start. - await super().start(frame) - if self._initialized: return self._initialized = True + # Parent start. + await super().start(frame) + + # Setup client. + await self._client.setup(frame) + # Join the room. await self._client.join() + + # Indicate the transport that we are connected. + await self.set_transport_ready(frame) + if self._params.audio_in_stream_on_start: self.start_audio_in_streaming() @@ -1052,12 +1061,13 @@ class DailyInputTransport(BaseInputTransport): video_source: str = "camera", color_format: str = "RGB", ): - self._video_renderers[participant_id] = { - video_source: { - "framerate": framerate, - "timestamp": 0, - "render_next_frame": [], - } + if participant_id not in self._video_renderers: + self._video_renderers[participant_id] = {} + + self._video_renderers[participant_id][video_source] = { + "framerate": framerate, + "timestamp": 0, + "render_next_frame": [], } await self._client.capture_participant_video( @@ -1066,7 +1076,8 @@ class DailyInputTransport(BaseInputTransport): async def request_participant_image(self, frame: UserImageRequestFrame): if frame.user_id in self._video_renderers: - self._video_renderers[frame.user_id]["render_next_frame"].append(frame) + video_source = frame.video_source if frame.video_source else "camera" + self._video_renderers[frame.user_id][video_source]["render_next_frame"].append(frame) async def _on_participant_video_frame( self, participant_id: str, video_frame: VideoFrame, video_source: str @@ -1125,20 +1136,23 @@ class DailyOutputTransport(BaseOutputTransport): self._initialized = False async def start(self, frame: StartFrame): - # Setup client. - await self._client.setup(frame) - - # Parent start. - await super().start(frame) - if self._initialized: return self._initialized = True + # Parent start. + await super().start(frame) + + # Setup client. + await self._client.setup(frame) + # Join the room. await self._client.join() + # Indicate the transport that we are connected. + await self.set_transport_ready(frame) + async def stop(self, frame: EndFrame): # Parent stop. await super().stop(frame) @@ -1201,6 +1215,7 @@ class DailyTransport(BaseTransport): super().__init__(input_name=input_name, output_name=output_name) callbacks = DailyCallbacks( + on_active_speaker_changed=self._on_active_speaker_changed, on_joined=self._on_joined, on_left=self._on_left, on_error=self._on_error, @@ -1236,6 +1251,7 @@ class DailyTransport(BaseTransport): # Register supported handlers. The user will only be able to register # these handlers. + self._register_event_handler("on_active_speaker_changed") self._register_event_handler("on_joined") self._register_event_handler("on_left") self._register_event_handler("on_error") @@ -1370,6 +1386,9 @@ class DailyTransport(BaseTransport): async def update_remote_participants(self, remote_participants: Mapping[str, Any]): await self._client.update_remote_participants(remote_participants=remote_participants) + async def _on_active_speaker_changed(self, participant: Any): + await self._call_event_handler("on_active_speaker_changed", participant) + async def _on_joined(self, data): await self._call_event_handler("on_joined", data) diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 456a70ea6..36cc5d604 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -370,6 +370,7 @@ class LiveKitInputTransport(BaseInputTransport): await self._client.connect() if not self._audio_in_task and self._params.audio_in_enabled: self._audio_in_task = self.create_task(self._audio_in_task_handler()) + await self.set_transport_ready(frame) logger.info("LiveKitInputTransport started") async def stop(self, frame: EndFrame): @@ -441,6 +442,7 @@ class LiveKitOutputTransport(BaseOutputTransport): await super().start(frame) await self._client.setup(frame) await self._client.connect() + await self.set_transport_ready(frame) logger.info("LiveKitOutputTransport started") async def stop(self, frame: EndFrame): diff --git a/src/pipecat/utils/asyncio.py b/src/pipecat/utils/asyncio.py index acc4acec8..cea447329 100644 --- a/src/pipecat/utils/asyncio.py +++ b/src/pipecat/utils/asyncio.py @@ -6,7 +6,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Coroutine, Optional, Set +from typing import Coroutine, Dict, Optional, Sequence, Set from loguru import logger @@ -69,14 +69,14 @@ class BaseTaskManager(ABC): pass @abstractmethod - def current_tasks(self) -> Set[asyncio.Task]: + def current_tasks(self) -> Sequence[asyncio.Task]: """Returns the list of currently created/registered tasks.""" pass class TaskManager(BaseTaskManager): def __init__(self) -> None: - self._tasks: Set[asyncio.Task] = set() + self._tasks: Dict[str, asyncio.Task] = {} self._loop: Optional[asyncio.AbstractEventLoop] = None def set_event_loop(self, loop: asyncio.AbstractEventLoop): @@ -179,16 +179,17 @@ class TaskManager(BaseTaskManager): finally: self._remove_task(task) - def current_tasks(self) -> Set[asyncio.Task]: + def current_tasks(self) -> Sequence[asyncio.Task]: """Returns the list of currently created/registered tasks.""" - return self._tasks + return list(self._tasks.values()) def _add_task(self, task: asyncio.Task): - self._tasks.add(task) + name = task.get_name() + self._tasks[name] = task def _remove_task(self, task: asyncio.Task): name = task.get_name() try: - self._tasks.remove(task) + del self._tasks[name] except KeyError as e: logger.trace(f"{name}: unable to remove task (already removed?): {e}") diff --git a/test-requirements.txt b/test-requirements.txt index b34a53ab9..fec8adf52 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1 +1 @@ --e ".[anthropic,google,langchain]" +-e ".[anthropic,aws,google,langchain]" diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index dfe210e07..0f68110ce 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -40,6 +40,11 @@ from pipecat.services.anthropic.llm import ( AnthropicLLMContext, AnthropicUserContextAggregator, ) +from pipecat.services.aws.llm import ( + AWSBedrockAssistantContextAggregator, + AWSBedrockLLMContext, + AWSBedrockUserContextAggregator, +) from pipecat.services.google.llm import ( GoogleAssistantContextAggregator, GoogleLLMContext, @@ -669,26 +674,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola AGGREGATOR_CLASS = LLMUserContextAggregator -# -# OpenAI -# - - -class TestOpenAIUserContextAggregator( - BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = OpenAIUserContextAggregator - - -class TestOpenAIAssistantContextAggregator( - BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = OpenAIAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - # # Anthropic # @@ -724,6 +709,43 @@ class TestAnthropicAssistantContextAggregator( assert context.messages[index]["content"][0]["content"] == json.dumps(content) +# +# AWS (Bedrock) +# + + +class TestAWSBedrockUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AWSBedrockLLMContext + AGGREGATOR_CLASS = AWSBedrockUserContextAggregator + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + +class TestAWSBedrockAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AWSBedrockLLMContext + AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): + assert context.messages[index]["content"][0]["toolResult"]["content"][0][ + "text" + ] == json.dumps(content) + + # # Google # @@ -766,3 +788,23 @@ class TestGoogleAssistantContextAggregator( def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): obj = glm.Content.to_dict(context.messages[index]) assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content) + + +# +# OpenAI +# + + +class TestOpenAIUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = OpenAIUserContextAggregator + + +class TestOpenAIAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = OpenAIAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] diff --git a/tests/test_function_calling_adapters.py b/tests/test_function_calling_adapters.py index 5d6dafce3..83640bb80 100644 --- a/tests/test_function_calling_adapters.py +++ b/tests/test_function_calling_adapters.py @@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter +from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter @@ -174,3 +175,32 @@ class TestFunctionAdapters(unittest.TestCase): tools_def = self.tools_def tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]} assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected + + def test_bedrock_adapter(self): + """Test AWS Bedrock adapter format transformation.""" + expected = [ + { + "toolSpec": { + "name": "get_weather", + "description": "Get the weather in a given location", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use.", + }, + "location": { + "type": "string", + "description": "The city, e.g. San Francisco", + }, + }, + "required": ["location", "format"], + } + }, + } + } + ] + assert AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def) == expected