Compare commits
166 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d02d886f3 | ||
|
|
42289eb30d | ||
|
|
a0c93ab6de | ||
|
|
4bec566bbf | ||
|
|
ec3cd24182 | ||
|
|
e36e64c2e8 | ||
|
|
02a88022dd | ||
|
|
6cae61f2cc | ||
|
|
3b40079120 | ||
|
|
ff0b38859b | ||
|
|
4d499324d1 | ||
|
|
f13e006db2 | ||
|
|
87d9e8c9cd | ||
|
|
4820f1c059 | ||
|
|
860c39d1b1 | ||
|
|
ae5c5ed7f6 | ||
|
|
7aa01c1ca8 | ||
|
|
4d6356748f | ||
|
|
5b1a182421 | ||
|
|
6ac0c34413 | ||
|
|
c115422dbf | ||
|
|
a2a973be27 | ||
|
|
0407744950 | ||
|
|
7ce370ccc6 | ||
|
|
a4867f61aa | ||
|
|
a67a765783 | ||
|
|
81221668b1 | ||
|
|
cc9c264940 | ||
|
|
f2c61ac9fd | ||
|
|
88f8c10f63 | ||
|
|
855f4842dd | ||
|
|
2bf44fe2af | ||
|
|
3e8a7cc254 | ||
|
|
a600c05570 | ||
|
|
3ba6b55659 | ||
|
|
d5f2dcfac0 | ||
|
|
d12134038b | ||
|
|
a22af3a7e0 | ||
|
|
76e07c6c48 | ||
|
|
8d8503bca7 | ||
|
|
a444097060 | ||
|
|
1b9e96c016 | ||
|
|
7967bc53c3 | ||
|
|
6381335346 | ||
|
|
0fd5d26104 | ||
|
|
41f817bf04 | ||
|
|
27115e6565 | ||
|
|
3c4807d7d4 | ||
|
|
8902f1dc94 | ||
|
|
a25333ee51 | ||
|
|
82c7d7ad83 | ||
|
|
ba2ab51ef7 | ||
|
|
22557fa668 | ||
|
|
3fbf59e7c6 | ||
|
|
129ab5ea0e | ||
|
|
dc917523d0 | ||
|
|
5ea7cc9d32 | ||
|
|
e11ede475b | ||
|
|
90d29e04af | ||
|
|
4c67136a8d | ||
|
|
9d78402a33 | ||
|
|
73877218e9 | ||
|
|
6a1be90cbb | ||
|
|
fbac959ecb | ||
|
|
18dd85431c | ||
|
|
abc569b3d2 | ||
|
|
fa5d4ecf86 | ||
|
|
83b0dc39f7 | ||
|
|
0c31b5ef19 | ||
|
|
d16c36c56d | ||
|
|
8fe3bcd484 | ||
|
|
be2858bfbb | ||
|
|
b6b0997553 | ||
|
|
3b751322d3 | ||
|
|
fce6f55ddb | ||
|
|
d9580f72a9 | ||
|
|
cc66ac14f1 | ||
|
|
9ddec0f8b4 | ||
|
|
9babfe9fd9 | ||
|
|
21d8d148b8 | ||
|
|
0588c82bbf | ||
|
|
16e9093d5a | ||
|
|
91a5d580fd | ||
|
|
0473556992 | ||
|
|
fdaa4e476e | ||
|
|
502e7e42a7 | ||
|
|
2ab3d4fb42 | ||
|
|
55014bdd77 | ||
|
|
334796bd65 | ||
|
|
1c25b6fb72 | ||
|
|
91b29de7ca | ||
|
|
21d610cd30 | ||
|
|
f7fe673ad1 | ||
|
|
4b415721e2 | ||
|
|
8d2a98e0e7 | ||
|
|
523e890c8c | ||
|
|
3c748fe772 | ||
|
|
d293cee372 | ||
|
|
8b62a96878 | ||
|
|
0c102ce70b | ||
|
|
3894d2a4b9 | ||
|
|
1f6b61c0db | ||
|
|
8ee28b37cd | ||
|
|
e85e7e4d84 | ||
|
|
1b3afb5511 | ||
|
|
7cec013666 | ||
|
|
86127167fb | ||
|
|
9935a68018 | ||
|
|
5679dde70f | ||
|
|
d81b0f6368 | ||
|
|
9698b008da | ||
|
|
7b05c9283b | ||
|
|
303dd2ec35 | ||
|
|
aa6e81648a | ||
|
|
1a87870ef3 | ||
|
|
aac4ce2d12 | ||
|
|
2a79b2c853 | ||
|
|
15bf5b1533 | ||
|
|
cdc86db8ce | ||
|
|
9d2ad750b5 | ||
|
|
19ceb1a48f | ||
|
|
59217eae38 | ||
|
|
bea0aee835 | ||
|
|
aeace9b9be | ||
|
|
2994640f47 | ||
|
|
10069719e4 | ||
|
|
046b76df60 | ||
|
|
f2d9063984 | ||
|
|
7c1e2793c5 | ||
|
|
99f008e927 | ||
|
|
2699f0c2a6 | ||
|
|
0b6dd98000 | ||
|
|
a14fb20d15 | ||
|
|
728361a6a7 | ||
|
|
106db69e8e | ||
|
|
cf90071926 | ||
|
|
deaeb75a1f | ||
|
|
a666327d70 | ||
|
|
13a0522546 | ||
|
|
7da37a0d1f | ||
|
|
7efb22a323 | ||
|
|
8084e2f909 | ||
|
|
86127c6a6e | ||
|
|
402e019ae2 | ||
|
|
f09e4e238b | ||
|
|
2921162b3b | ||
|
|
ac1582c906 | ||
|
|
e4b01a5844 | ||
|
|
fa663abbbc | ||
|
|
d19e6111c3 | ||
|
|
8a6d504a7e | ||
|
|
43915937f2 | ||
|
|
48e92a22fe | ||
|
|
566af6b0b8 | ||
|
|
12e7613d5f | ||
|
|
04a68f2c57 | ||
|
|
9b4ca12f49 | ||
|
|
453ce715a6 | ||
|
|
d87b6189ba | ||
|
|
8293347b77 | ||
|
|
c85a3f0b94 | ||
|
|
233fb25e6c | ||
|
|
080978daa6 | ||
|
|
62b7c3d3b2 | ||
|
|
066b77fba0 | ||
|
|
607e3040d4 |
13
.github/workflows/build.yaml
vendored
13
.github/workflows/build.yaml
vendored
@@ -21,20 +21,21 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for setuptools_scm
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.10
|
||||
|
||||
|
||||
- name: Install development dependencies
|
||||
run: uv sync --group dev
|
||||
|
||||
|
||||
- name: Build project
|
||||
run: uv build
|
||||
|
||||
|
||||
- name: Install project in editable mode
|
||||
run: uv pip install --editable .
|
||||
run: uv pip install --editable .
|
||||
|
||||
222
CHANGELOG.md
222
CHANGELOG.md
@@ -5,6 +5,214 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added the [Pipecat CLI](https://github.com/pipecat-ai/pipecat-cli) to the
|
||||
required dependencies, enabling you to scaffold a new project directly from
|
||||
`pipecat-ai`. Get started with:
|
||||
|
||||
```bash
|
||||
uv run pipecat init
|
||||
```
|
||||
|
||||
- Expanded support for universal `LLMContext` to `AWSNovaSonicLLMService`.
|
||||
As a reminder, the context-setup pattern when using `LLMContext` is:
|
||||
|
||||
```python
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
```
|
||||
|
||||
(Note that even though `AWSNovaSonicLLMService` now supports the universal
|
||||
`LLMContext`, it is not meant to be swapped out for another LLM service at
|
||||
runtime.)
|
||||
|
||||
Worth noting: whether or not you use the new context-setup pattern with
|
||||
`AWSNovaSonicLLMService`, some types have changed under the hood:
|
||||
|
||||
```python
|
||||
## BEFORE:
|
||||
|
||||
# Context aggregator type
|
||||
context_aggregator: AWSNovaSonicContextAggregatorPair
|
||||
|
||||
# Context frame type
|
||||
frame: OpenAILLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: AWSNovaSonicLLMContext
|
||||
# or
|
||||
context: OpenAILLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.messages
|
||||
|
||||
## AFTER:
|
||||
|
||||
# Context aggregator type
|
||||
context_aggregator: LLMContextAggregatorPair
|
||||
|
||||
# Context frame type
|
||||
frame: LLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: LLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.get_messages()
|
||||
```
|
||||
|
||||
- Added support for `bulbul:v3` model in `SarvamTTSService` and
|
||||
`SarvamHttpTTSService`.
|
||||
|
||||
- Added `keyterms_prompt` parameter to `AssemblyAIConnectionParams`.
|
||||
|
||||
- Added `speech_model` parameter to `AssemblyAIConnectionParams` to access the
|
||||
multilingual model.
|
||||
|
||||
- Added support for trickle ICE to the `SmallWebRTCTransport`.
|
||||
|
||||
- Added support for updating `OpenAITTSService` settings (`instructions` and
|
||||
`speed`) at runtime via `TTSUpdateSettingsFrame`.
|
||||
|
||||
- Added `--whatsapp` flag to runner to better surface WhatsApp transport logs.
|
||||
|
||||
- Added `on_connected` and `on_disconnected` events to TTS and STT
|
||||
websocket-based services.
|
||||
|
||||
- Added an `aggregate_sentences` arg in `ElevenLabsHttpTTSService`, where the
|
||||
default value is `True`.
|
||||
|
||||
- Added a `room_properties` arg to the Daily runner's `configure()` method,
|
||||
allowing `DailyRoomProperties` to be provided.
|
||||
|
||||
- The runner `--folder` argument now supports downloading files from
|
||||
subdirectories.
|
||||
|
||||
### Changed
|
||||
|
||||
- `CartesiaSTTService` now inherits from `WebsocketSTTService`.
|
||||
|
||||
- Package upgrades:
|
||||
|
||||
- `daily-python` upgraded to 0.20.0.
|
||||
- `openai` upgraded to support up to 2.x.x.
|
||||
- `openpipe` upgraded to support up to 5.x.x.
|
||||
|
||||
- `SpeechmaticsSTTService` updated dependencies for `speechmatics-rt>=0.5.0`.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The `send_transcription_frames` argument to `AWSNovaSonicLLMService` is
|
||||
deprecated. Transcription frames are now always sent. They go upstream, to be
|
||||
handled by the user context aggregator. See "Added" section for details.
|
||||
|
||||
- Types in `pipecat.services.aws.nova_sonic.context` have been deprecated due
|
||||
to changes to support `LLMContext`. See "Added" section for details.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `RivaSegmentedSTTService` where a runtime error occurred due
|
||||
to a mismatch in the `_handle_transcription` method's signature.
|
||||
|
||||
- Fixed multiple pipeline task cancellation issues. `asyncio.CancelledError` is
|
||||
now handled properly in `PipelineTask` making it possible to cancel an asyncio
|
||||
task that is executing a `PipelineRunner` cleanly. Also,
|
||||
`PipelineTask.cancel()` does not block anymore waiting for the `CancelFrame`
|
||||
to reach the end of the pipeline (going back to the behavior in < 0.0.83).
|
||||
|
||||
- Fixed an issue in `ElevenLabsTTSService` and `ElevenLabsHttpTTSService` where
|
||||
the Flash models would split words, resulting in a space being inserted
|
||||
between words.
|
||||
|
||||
- Fixed an issue where audio filters' `stop()` would not be called when using
|
||||
`CancelFrame`.
|
||||
|
||||
- Fixed an issue in `ElevenLabsHttpTTSService`, where
|
||||
`apply_text_normalization` was incorrectly set as a query parameter. It's now
|
||||
being added as a request parameter.
|
||||
|
||||
- Fixed an issue where `RimeHttpTTSService` and `PiperTTSService` could generate
|
||||
audio frames that were not 16-bit aligned, potentially leading to internal
|
||||
errors or static audio.
|
||||
|
||||
- Fixed an issue in `SpeechmaticsSTTService` where `AdditionalVocabEntry` items
|
||||
needed to have `sounds_like` for the session to start.
|
||||
|
||||
### Other
|
||||
|
||||
- Added foundational example `47-sentry-metrics.py`, demonstrating how to use the
|
||||
`SentryMetrics` processor.
|
||||
|
||||
- Added foundational example `14x-function-calling-openpipe.py`.
|
||||
|
||||
## [0.0.90] - 2025-10-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added audio filter `KrispVivaFilter` using the Krisp VIVA SDK.
|
||||
|
||||
- Added `--folder` argument to the runner, allowing files saved in that folder
|
||||
to be downloaded from `http://HOST:PORT/file/FILE`.
|
||||
|
||||
- Added `GeminiLiveVertexLLMService`, for accessing Gemini Live via Google
|
||||
Vertex AI.
|
||||
|
||||
- Added some new configuration options to `GeminiLiveLLMService`:
|
||||
|
||||
- `thinking`
|
||||
- `enable_affective_dialog`
|
||||
- `proactivity`
|
||||
|
||||
Note that these new configuration options require using a newer model than
|
||||
the default, like "gemini-2.5-flash-native-audio-preview-09-2025". The last
|
||||
two require specifying `http_options=HttpOptions(api_version="v1alpha")`.
|
||||
|
||||
- Added `on_pipeline_error` event to `PipelineTask`. This event will get fired
|
||||
when an `ErrorFrame` is pushed (use `FrameProcessor.push_error()`).
|
||||
|
||||
```python
|
||||
@task.event_handler("on_pipeline_error")
|
||||
async def on_pipeline_error(task: PipelineTask, frame: ErrorFrame):
|
||||
...
|
||||
```
|
||||
|
||||
- Added a `service_tier` `InputParam` to the `BaseOpenAILLMService`. This
|
||||
parameter can influence the latency of the response. For example `"priority"`
|
||||
will result in faster completions, but in exchange for a higher price.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `GeminiLiveLLMService` to use the `google-genai` library rather than
|
||||
use WebSockets directly.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `LivekitFrameSerializer` is now deprecated. Use `LiveKitTransport` instead.
|
||||
|
||||
- `pipecat.services.openai_realtime` is now deprecated, use
|
||||
`pipecat.services.openai.realtime` instead or
|
||||
`pipecat.services.azure.realtime` for Azure Realtime.
|
||||
|
||||
- `pipecat.services.aws_nova_sonic` is now deprecated, use
|
||||
`pipecat.services.aws.nova_sonic` instead.
|
||||
|
||||
- `GeminiMultimodalLiveLLMService` is now deprecated, use
|
||||
`GeminiLiveLLMService`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `GoogleVertexLLMService` issue that would generate an error if no
|
||||
token information was returned.
|
||||
|
||||
- `GeminiLiveLLMService` will now end gracefully (i.e. after the bot has
|
||||
finished) upon receiving an `EndFrame`.
|
||||
|
||||
- `GeminiLiveLLMService` will try to seamlessly reconnect when it loses its
|
||||
connection.
|
||||
|
||||
## [0.0.89] - 2025-10-07
|
||||
|
||||
### Fixed
|
||||
@@ -23,8 +231,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added `HumeTTSService` for text-to-speech synthesis using Hume AI's expressive
|
||||
voice models. Provides high-quality, emotionally expressive speech synthesis
|
||||
with support for various voice models. Includes example in
|
||||
`examples/foundational/07ad-interruptible-hume.py`. Use with `uv pip install
|
||||
pipecat-ai[hume]`.
|
||||
`examples/foundational/07ad-interruptible-hume.py`. Use with:
|
||||
`uv pip install pipecat-ai[hume]`.
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -1422,7 +1630,7 @@ quality and critical bugs impacting `ParallelPipelines` functionality.**
|
||||
- Added `session_token` parameter to `AWSNovaSonicLLMService`.
|
||||
|
||||
- Added Gemini Multimodal Live File API for uploading, fetching, listing, and
|
||||
deleting files. See `26f-gemini-multimodal-live-files-api.py` for example usage.
|
||||
deleting files. See `26f-gemini-live-files-api.py` for example usage.
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -3428,7 +3636,7 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
||||
- Added the new modalities option and helper function to set Gemini output
|
||||
modalities.
|
||||
|
||||
- Added `examples/foundational/26d-gemini-multimodal-live-text.py` which is
|
||||
- Added `examples/foundational/26d-gemini-live-text.py` which is
|
||||
using Gemini with the TEXT modality and another TTS provider for speech synthesis.
|
||||
|
||||
### Changed
|
||||
@@ -3615,9 +3823,9 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
||||
- Added new foundational examples for `GeminiMultimodalLiveLLMService`:
|
||||
|
||||
- `26-gemini-multimodal-live.py`
|
||||
- `26a-gemini-multimodal-live-transcription.py`
|
||||
- `26b-gemini-multimodal-live-video.py`
|
||||
- `26c-gemini-multimodal-live-video.py`
|
||||
- `26a-gemini-live-transcription.py`
|
||||
- `26b-gemini-live-video.py`
|
||||
- `26c-gemini-live-video.py`
|
||||
|
||||
- Added `SimliVideoService`. This is an integration for Simli AI avatars.
|
||||
(see https://www.simli.com)
|
||||
|
||||
35
README.md
35
README.md
@@ -3,6 +3,7 @@
|
||||
</div></h1>
|
||||
|
||||
[](https://pypi.org/project/pipecat-ai)  [](https://codecov.io/gh/pipecat-ai/pipecat) [](https://docs.pipecat.ai) [](https://discord.gg/pipecat) [](https://deepwiki.com/pipecat-ai/pipecat)
|
||||
[](https://getmanta.ai/pipecat)
|
||||
|
||||
# 🎙️ Pipecat: Real-Time Voice & Multimodal AI Agents
|
||||
|
||||
@@ -43,6 +44,10 @@ Looking to build structured conversations? Check out [Pipecat Flows](https://git
|
||||
|
||||
Want to build beautiful and engaging experiences? Check out the [Voice UI Kit](https://github.com/pipecat-ai/voice-ui-kit), a collection of components, hooks and templates for building voice AI applications quickly.
|
||||
|
||||
### 🛠️ CLI
|
||||
|
||||
Create a new project in under a minute with the [Pipecat CLI](https://github.com/pipecat-ai/pipecat-cli). Then use the CLI to monitor and deploy your agent to production.
|
||||
|
||||
### 🔍 Debugging
|
||||
|
||||
Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
|
||||
@@ -51,6 +56,10 @@ Looking for help debugging your pipeline and processors? Check out [Whisker](htt
|
||||
|
||||
Love terminal applications? Check out [Tail](https://github.com/pipecat-ai/tail), a terminal dashboard for Pipecat.
|
||||
|
||||
### 📺️ Pipecat TV Channel
|
||||
|
||||
Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.youtube.com/playlist?list=PLzU2zoMTQIHjqC3v4q2XVSR3hGSzwKFwH) channel.
|
||||
|
||||
## 🎬 See it in action
|
||||
|
||||
<p float="left">
|
||||
@@ -58,24 +67,24 @@ Love terminal applications? Check out [Tail](https://github.com/pipecat-ai/tail)
|
||||
<a href="https://github.com/pipecat-ai/pipecat-examples/tree/main/storytelling-chatbot"><img src="https://raw.githubusercontent.com/pipecat-ai/pipecat-examples/main/storytelling-chatbot/image.png" width="400" /></a>
|
||||
<br/>
|
||||
<a href="https://github.com/pipecat-ai/pipecat-examples/tree/main/translation-chatbot"><img src="https://raw.githubusercontent.com/pipecat-ai/pipecat-examples/main/translation-chatbot/image.png" width="400" /></a>
|
||||
<a href="https://github.com/pipecat-ai/pipecat-examples/tree/main/moondream-chatbot"><img src="https://raw.githubusercontent.com/pipecat-ai/pipecat-examples/main/moondream-chatbot/image.png" width="400" /></a>
|
||||
<a href="https://github.com/pipecat-ai/pipecat/blob/main/examples/foundational/12-describe-video.py"><img src="https://github.com/pipecat-ai/pipecat/blob/main/examples/foundational/assets/moondream.png" width="400" /></a>
|
||||
</p>
|
||||
|
||||
## 🧩 Available services
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Category | Services |
|
||||
| ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
|
||||
| Analytics & Metrics | [OpenTelemetry](https://docs.pipecat.ai/server/utilities/opentelemetry), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
|
||||
| Analytics & Metrics | [OpenTelemetry](https://docs.pipecat.ai/server/utilities/opentelemetry), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ autodoc_mock_imports = [
|
||||
# Krisp - has build issues on some platforms
|
||||
"pipecat_ai_krisp",
|
||||
"krisp",
|
||||
"krisp_audio",
|
||||
# System-specific GUI libraries
|
||||
"_tkinter",
|
||||
"tkinter",
|
||||
|
||||
@@ -90,6 +90,9 @@ SIMLI_FACE_ID=...
|
||||
# Krisp
|
||||
KRISP_MODEL_PATH=...
|
||||
|
||||
# Krisp Viva
|
||||
KRISP_VIVA_MODEL_PATH=...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=...
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -58,7 +58,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = CartesiaSTTService(api_key=os.getenv("CARTESIA_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
|
||||
129
examples/foundational/07p-interruptible-krisp-viva.py
Normal file
129
examples/foundational/07p-interruptible-krisp-viva.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.filters.krisp_viva_filter import KrispVivaFilter
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Transport settings are kept as zero-argument factories instead of ready-made
# objects so that nothing (e.g. SileroVADAnalyzer) is instantiated up front.
# Only the factory belonging to the selected transport is ever called.


def _krisp_viva_audio_settings():
    """Return fresh keyword arguments shared by every transport in this example.

    A new VAD analyzer, turn analyzer and Krisp Viva input filter are created
    on each call, so no instance is shared between transports.
    """
    return dict(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        audio_in_filter=KrispVivaFilter(),
    )


transport_params = {
    "daily": lambda: DailyParams(**_krisp_viva_audio_settings()),
    "twilio": lambda: FastAPIWebsocketParams(**_krisp_viva_audio_settings()),
    "webrtc": lambda: TransportParams(**_krisp_viva_audio_settings()),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble the speech pipeline for this example and run it to completion.

    Frames flow: transport input -> STT -> user context aggregator -> LLM ->
    TTS -> transport output -> assistant context aggregator.

    Args:
        transport: The transport whose input/output processors bracket the pipeline.
        runner_args: Supplies the pipeline idle timeout and the SIGINT handling flag.
    """
    logger.info(f"Starting bot")

    # Speech-to-text and text-to-speech both use Deepgram; the LLM is OpenAI.
    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
    tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    system_message = {
        "role": "system",
        "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
    }
    # This list is handed to the context and appended to again when a client connects.
    messages = [system_message]

    context = LLMContext(messages)
    context_aggregator = LLMContextAggregatorPair(context)

    processors = [
        transport.input(),  # Transport user input
        stt,  # STT
        context_aggregator.user(),  # User responses
        llm,  # LLM
        tts,  # TTS
        transport.output(),  # Transport bot output
        context_aggregator.assistant(),  # Assistant spoken responses
    ]
    pipeline = Pipeline(processors)

    pipeline_params = PipelineParams(
        enable_metrics=True,
        enable_usage_metrics=True,
    )
    task = PipelineTask(
        pipeline,
        params=pipeline_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info(f"Client connected")
        # Kick off the conversation by asking the LLM to introduce itself.
        intro_request = {"role": "system", "content": "Please introduce yourself to the user."}
        messages.append(intro_request)
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info(f"Client disconnected")
        # Cancel the pipeline task once the client has gone away.
        await task.cancel()

    await PipelineRunner(handle_sigint=runner_args.handle_sigint).run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Entry point compatible with Pipecat Cloud: create the transport, then run the bot."""
    # create_transport receives the factory table so it can build the matching params.
    selected_transport = await create_transport(runner_args, transport_params)
    await run_bot(selected_transport, runner_args)
|
||||
|
||||
|
||||
# Script entry point: the runner's main() is imported only when this file is executed directly.
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -48,10 +48,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = CartesiaSTTService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
base_url=os.getenv("CARTESIA_BASE_URL"),
|
||||
)
|
||||
stt = CartesiaSTTService(api_key=os.getenv("CARTESIA_API_KEY"))
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
|
||||
@@ -76,9 +76,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = GoogleVertexLLMService(
|
||||
credentials=os.getenv("GOOGLE_VERTEX_TEST_CREDENTIALS"),
|
||||
params=GoogleVertexLLMService.InputParams(
|
||||
project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
|
||||
),
|
||||
project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
|
||||
location=os.getenv("GOOGLE_CLOUD_LOCATION"),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
|
||||
182
examples/foundational/14x-function-calling-openpipe.py
Normal file
182
examples/foundational/14x-function-calling-openpipe.py
Normal file
@@ -0,0 +1,182 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openpipe.llm import OpenPipeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
    """Report a fixed weather result through the function call's result callback.

    The call arguments are not inspected; the same canned payload is always sent.
    """
    canned_weather = {"conditions": "nice", "temperature": "75"}
    await params.result_callback(canned_weather)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
    """Report a fixed restaurant recommendation through the result callback."""
    recommendation = {"name": "The Golden Dragon"}
    await params.result_callback(recommendation)
|
||||
|
||||
|
||||
# Transport settings are kept as zero-argument factories instead of ready-made
# objects so that nothing (e.g. SileroVADAnalyzer) is instantiated up front.
# Only the factory belonging to the selected transport is ever called.


def _default_audio_settings():
    """Return fresh keyword arguments shared by every transport in this example.

    A new VAD analyzer and turn analyzer are created on each call, so no
    instance is shared between transports.
    """
    return dict(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
    )


transport_params = {
    "daily": lambda: DailyParams(**_default_audio_settings()),
    "twilio": lambda: FastAPIWebsocketParams(**_default_audio_settings()),
    "webrtc": lambda: TransportParams(**_default_audio_settings()),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble the function-calling pipeline backed by OpenPipe and run it.

    Frames flow: transport input -> STT -> user context aggregator -> LLM ->
    TTS -> transport output -> assistant context aggregator. Two tools
    (weather and restaurant recommendation) are registered on the LLM and
    advertised to it through the context.

    Args:
        transport: The transport whose input/output processors bracket the pipeline.
        runner_args: Supplies the pipeline idle timeout and the SIGINT handling flag.
    """
    logger.info(f"Starting bot")

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # Tag every request of this run with a conversation id derived from the start time.
    conversation_id = f"pipecat-{int(time.time())}"
    llm = OpenPipeLLMService(
        api_key=os.getenv("OPENAI_API_KEY"),
        openpipe_api_key=os.getenv("OPENPIPE_API_KEY"),
        tags={"conversation_id": conversation_id},
    )

    # You can also register a function_name of None to get all functions
    # sent to the same callback with an additional function_name parameter.
    llm.register_function("get_current_weather", fetch_weather_from_api)
    llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)

    @llm.event_handler("on_function_calls_started")
    async def on_function_calls_started(service, function_calls):
        # Speak a short filler phrase when function calls begin.
        await tts.queue_frame(TTSSpeakFrame("Let me check on that."))

    weather_properties = {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the user's location.",
        },
    }
    weather_function = FunctionSchema(
        name="get_current_weather",
        description="Get the current weather",
        properties=weather_properties,
        required=["location", "format"],
    )

    restaurant_properties = {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
    }
    restaurant_function = FunctionSchema(
        name="get_restaurant_recommendation",
        description="Get a restaurant recommendation",
        properties=restaurant_properties,
        required=["location"],
    )

    tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])

    system_message = {
        "role": "system",
        "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
    }
    messages = [system_message]

    context = LLMContext(messages, tools)
    context_aggregator = LLMContextAggregatorPair(context)

    processors = [
        transport.input(),
        stt,
        context_aggregator.user(),
        llm,
        tts,
        transport.output(),
        context_aggregator.assistant(),
    ]
    pipeline = Pipeline(processors)

    pipeline_params = PipelineParams(
        enable_metrics=True,
        enable_usage_metrics=True,
    )
    task = PipelineTask(
        pipeline,
        params=pipeline_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info(f"Client connected")
        # Kick off the conversation.
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info(f"Client disconnected")
        # Cancel the pipeline task once the client has gone away.
        await task.cancel()

    await PipelineRunner(handle_sigint=runner_args.handle_sigint).run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Entry point compatible with Pipecat Cloud: create the transport, then run the bot."""
    # create_transport receives the factory table so it can build the matching params.
    selected_transport = await create_transport(runner_args, transport_params)
    await run_bot(selected_transport, runner_args)
|
||||
|
||||
|
||||
# Script entry point: the runner's main() is imported only when this file is executed directly.
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -24,14 +24,15 @@ from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -21,13 +21,14 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
AzureRealtimeLLMService,
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
InputAudioTranscription,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -22,16 +22,17 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -25,13 +25,14 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeLLMService,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -72,7 +72,6 @@ async def save_conversation(params: FunctionCallParams):
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = params.context.get_messages()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
|
||||
@@ -90,7 +90,6 @@ async def save_conversation(params: FunctionCallParams):
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = params.context.get_messages()
|
||||
# remove the last message (the instruction to save the context)
|
||||
messages.pop()
|
||||
|
||||
@@ -20,10 +20,12 @@ from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws_nova_sonic.aws import AWSNovaSonicLLMService
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -75,7 +77,7 @@ async def save_conversation(params: FunctionCallParams):
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
messages = params.context.get_messages()
|
||||
# remove the last few messages. in reverse order, they are:
|
||||
# - the in progress save tool call
|
||||
# - the invocation of the save tool call
|
||||
@@ -223,13 +225,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
|
||||
@@ -17,7 +17,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -65,7 +65,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
"""
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
voice_id="Puck", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
@@ -20,7 +20,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -65,7 +65,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
voice_id="Aoede", # Puck, Charon, Kore, Fenrir, Aoede
|
||||
# system_instruction="Talk like a pirate."
|
||||
@@ -22,7 +22,7 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -122,12 +122,15 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
required=["location"],
|
||||
)
|
||||
search_tool = {"google_search": {}}
|
||||
# KNOWN ISSUE: If using GeminiVertexLiveLLMService, it appears
|
||||
# you cannot use the "google_search" tool alongside other tools.
|
||||
# See https://github.com/googleapis/python-genai/issues/941.
|
||||
tools = ToolsSchema(
|
||||
standard_tools=[weather_function, restaurant_function],
|
||||
custom_tools={AdapterType.GEMINI: [search_tool]},
|
||||
)
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
tools=tools,
|
||||
@@ -24,7 +24,7 @@ from pipecat.runner.utils import (
|
||||
maybe_capture_participant_camera,
|
||||
maybe_capture_participant_screen,
|
||||
)
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
@@ -58,7 +58,7 @@ transport_params = {
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
voice_id="Aoede", # Puck, Charon, Kore, Fenrir, Aoede
|
||||
# system_instruction="Talk like a pirate."
|
||||
@@ -20,9 +20,9 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.gemini_multimodal_live.gemini import (
|
||||
GeminiMultimodalLiveLLMService,
|
||||
GeminiMultimodalModalities,
|
||||
from pipecat.services.google.gemini_live.llm import (
|
||||
GeminiLiveLLMService,
|
||||
GeminiModalities,
|
||||
InputParams,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
@@ -80,11 +80,15 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
# KNOWN ISSUE: If using GeminiLiveVertexLLMService, you cannot specify a
|
||||
# modality other than AUDIO (at least not if using the service's default
|
||||
# model, which is a native audio model:
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/live-api/tools#native-audio).
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=SYSTEM_INSTRUCTION,
|
||||
tools=[{"google_search": {}}, {"code_execution": {}}],
|
||||
params=InputParams(modalities=GeminiMultimodalModalities.TEXT),
|
||||
params=InputParams(modalities=GeminiModalities.TEXT),
|
||||
)
|
||||
|
||||
# Optionally, you can set the response modalities via a function
|
||||
@@ -19,7 +19,7 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -83,7 +83,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Initialize the Gemini Multimodal Live model
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
voice_id="Puck", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
system_instruction=system_instruction,
|
||||
@@ -19,9 +19,7 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import (
|
||||
GeminiMultimodalLiveLLMService,
|
||||
)
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -110,7 +108,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
"""
|
||||
|
||||
# Initialize Gemini service with File API support
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
voice_id="Charon", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
@@ -9,13 +9,13 @@ from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import Frame, LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -105,7 +105,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
custom_tools={AdapterType.GEMINI: [{"google_search": {}}, {"code_execution": {}}]},
|
||||
)
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=SYSTEM_INSTRUCTION,
|
||||
voice_id="Charon", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
191
examples/foundational/26h-gemini-live-vertex-function-calling.py
Normal file
191
examples/foundational/26h-gemini-live-vertex-function-calling.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from google.genai.types import HttpOptions
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm_vertex import GeminiLiveVertexLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
    """Send a canned weather report through the function call's result callback.

    The temperature is 75 when the requested "format" argument is
    "fahrenheit" and 24 for any other value. The payload echoes the
    requested format and carries a local-time timestamp.

    Raises:
        KeyError: If ``params.arguments`` has no "format" entry.
    """
    unit = params.arguments["format"]
    if unit == "fahrenheit":
        temperature = 75
    else:
        temperature = 24

    report = {
        "conditions": "nice",
        "temperature": temperature,
        "format": unit,
        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    }
    await params.result_callback(report)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
    """Answer a `get_restaurant_recommendation` call with a fixed restaurant name."""
    recommendation = {"name": "The Golden Dragon"}
    await params.result_callback(recommendation)
|
||||
|
||||
|
||||
# System prompt for the bot. The tool count stated here must match the tools
# that run_bot registers (get_current_weather and get_restaurant_recommendation).
system_instruction = """
You are a helpful assistant who can answer questions and use tools.

You have two tools available to you:
1. get_current_weather: Use this tool to get the current weather in a specific location.
2. get_restaurant_recommendation: Use this tool to get a restaurant recommendation in a specific location.
"""
|
||||
|
||||
|
||||
def _live_audio_kwargs():
    """Build the keyword arguments shared by every transport's params object.

    A new SileroVADAnalyzer is created on each call.
    """
    return {
        "audio_in_enabled": True,
        "audio_out_enabled": True,
        # set stop_secs to something roughly similar to the internal setting
        # of the Multimodal Live api, just to align events. This doesn't really
        # matter because we can only use the Multimodal Live API's phrase
        # endpointing, for now.
        "vad_analyzer": SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
    }


# Each value is a zero-argument factory, so nothing (e.g. SileroVADAnalyzer) is
# instantiated at import time; the factory is only invoked for the transport
# that actually gets selected.
transport_params = {
    "daily": lambda: DailyParams(**_live_audio_kwargs()),
    "twilio": lambda: FastAPIWebsocketParams(**_live_audio_kwargs()),
    "webrtc": lambda: TransportParams(**_live_audio_kwargs()),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble and run the Gemini Live (Vertex AI) function-calling pipeline.

    Declares the weather and restaurant tools, registers their handlers on a
    GeminiLiveVertexLLMService, wires the service between the transport's input
    and output, and runs the resulting task until it finishes or is cancelled.

    Args:
        transport: Transport whose input/output processors bracket the pipeline.
        runner_args: Runner arguments; supplies the pipeline idle timeout and
            the SIGINT-handling flag.
    """
    logger.info("Starting bot")

    weather_function = FunctionSchema(
        name="get_current_weather",
        description="Get the current weather",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit to use. Infer this from the user's location.",
            },
        },
        required=["location", "format"],
    )
    restaurant_function = FunctionSchema(
        name="get_restaurant_recommendation",
        description="Get a restaurant recommendation",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
        },
        required=["location"],
    )
    # KNOWN ISSUE: If using GeminiLiveVertexLLMService, it appears
    # you cannot use the "google_search" tool alongside other tools.
    # See https://github.com/googleapis/python-genai/issues/941.
    tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])

    llm = GeminiLiveVertexLLMService(
        credentials=os.getenv("GOOGLE_VERTEX_TEST_CREDENTIALS"),
        project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
        location=os.getenv("GOOGLE_CLOUD_LOCATION"),
        system_instruction=system_instruction,
        voice_id="Puck",  # Aoede, Charon, Fenrir, Kore, Puck
        tools=tools,
    )

    # The registered names must match the FunctionSchema names declared above.
    llm.register_function("get_current_weather", fetch_weather_from_api)
    llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)

    context = OpenAILLMContext(
        [{"role": "user", "content": "Say hello."}],
    )
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),
            context_aggregator.user(),
            llm,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Entry point compatible with Pipecat Cloud.

    Builds a transport from the runner arguments and `transport_params`, then
    runs the bot on it.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
204
examples/foundational/26i-gemini-live-graceful-end.py
Normal file
204
examples/foundational/26i-gemini-live-graceful-end.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import EndTaskFrame, LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
    """Answer a `get_current_weather` call with canned weather data.

    The temperature is 75 when the requested format is "fahrenheit" and 24
    otherwise; the result is delivered through `params.result_callback`.
    """
    if params.arguments["format"] == "fahrenheit":
        temperature = 75
    else:
        temperature = 24

    weather = {
        "conditions": "nice",
        "temperature": temperature,
        "format": params.arguments["format"],
        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    }
    await params.result_callback(weather)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
    """Answer a `get_restaurant_recommendation` call with a fixed restaurant name."""
    recommendation = {"name": "The Golden Dragon"}
    await params.result_callback(recommendation)
|
||||
|
||||
|
||||
async def end_conversation(params: FunctionCallParams):
    """Handle the `end_conversation` tool call.

    Reports success through the result callback first, then pushes an
    EndTaskFrame upstream from the LLM service.
    """
    outcome = {"success": True}
    await params.result_callback(outcome)

    end_frame = EndTaskFrame()
    await params.llm.push_frame(end_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
# System prompt: lists the three tools and tells the model to say goodbye and
# call end_conversation after its third response.
system_instruction = """
You are a helpful assistant who can answer questions and use tools.

You have three tools available to you:
1. get_current_weather: Use this tool to get the current weather in a specific location.
2. get_restaurant_recommendation: Use this tool to get a restaurant recommendation in a specific location.
3. end_conversation: Use this tool to gracefully end the conversation.

After you've responded to the user three times, do two things, in order:
1. Politely let them know that that's all the time you have today and say goodbye.
2. Call the end_conversation tool to gracefully end the conversation.
"""


def _live_audio_kwargs():
    """Build the keyword arguments shared by every transport's params object.

    A new SileroVADAnalyzer is created on each call.
    """
    return {
        "audio_in_enabled": True,
        "audio_out_enabled": True,
        # set stop_secs to something roughly similar to the internal setting
        # of the Multimodal Live api, just to align events. This doesn't really
        # matter because we can only use the Multimodal Live API's phrase
        # endpointing, for now.
        "vad_analyzer": SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
    }


# Each value is a zero-argument factory, so nothing (e.g. SileroVADAnalyzer) is
# instantiated at import time; the factory is only invoked for the transport
# that actually gets selected.
transport_params = {
    "daily": lambda: DailyParams(**_live_audio_kwargs()),
    "twilio": lambda: FastAPIWebsocketParams(**_live_audio_kwargs()),
    "webrtc": lambda: TransportParams(**_live_audio_kwargs()),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble and run the Gemini Live pipeline with a graceful-end tool.

    Declares the weather, restaurant and end_conversation tools (plus the
    Gemini-specific google_search tool), registers their handlers on a
    GeminiLiveLLMService, wires the service between the transport's input and
    output, and runs the resulting task until it finishes or is cancelled.

    Args:
        transport: Transport whose input/output processors bracket the pipeline.
        runner_args: Runner arguments; supplies the pipeline idle timeout and
            the SIGINT-handling flag.
    """
    logger.info("Starting bot")

    weather_function = FunctionSchema(
        name="get_current_weather",
        description="Get the current weather",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit to use. Infer this from the user's location.",
            },
        },
        required=["location", "format"],
    )
    restaurant_function = FunctionSchema(
        name="get_restaurant_recommendation",
        description="Get a restaurant recommendation",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
        },
        required=["location"],
    )
    # Takes no arguments: calling it is the whole signal.
    end_conversation_function = FunctionSchema(
        name="end_conversation",
        description="Gracefully end the conversation",
        properties={},
        required=[],
    )
    search_tool = {"google_search": {}}
    tools = ToolsSchema(
        standard_tools=[weather_function, restaurant_function, end_conversation_function],
        custom_tools={AdapterType.GEMINI: [search_tool]},
    )

    llm = GeminiLiveLLMService(
        api_key=os.getenv("GOOGLE_API_KEY"),
        system_instruction=system_instruction,
        tools=tools,
    )

    # The registered names must match the FunctionSchema names declared above.
    llm.register_function("get_current_weather", fetch_weather_from_api)
    llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
    llm.register_function("end_conversation", end_conversation)

    context = OpenAILLMContext(
        [{"role": "user", "content": "Say hello."}],
    )
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),
            context_aggregator.user(),
            llm,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Entry point compatible with Pipecat Cloud.

    Builds a transport from the runner arguments and `transport_params`, then
    runs the bot on it.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -18,10 +18,11 @@ from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -119,9 +120,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
# AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to
|
||||
# what's expected by Nova Sonic.
|
||||
context = OpenAILLMContext(
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
{
|
||||
@@ -131,7 +130,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Build the pipeline
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -20,7 +20,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gemini_multimodal_live import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams, DailyTransport
|
||||
|
||||
@@ -94,7 +94,7 @@ Respond to what the user said in a creative and helpful way. Keep your responses
|
||||
|
||||
|
||||
async def run_bot(pipecat_transport):
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
voice_id="Puck", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
transcribe_user_audio=True,
|
||||
|
||||
142
examples/foundational/47-sentry-metrics.py
Normal file
142
examples/foundational/47-sentry-metrics.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import sentry_sdk
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.metrics.sentry import SentryMetrics
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
def _shared_transport_kwargs():
    """Build the keyword arguments shared by every transport's params object.

    A new VAD analyzer and a new smart-turn analyzer are created on each call.
    """
    return {
        "audio_in_enabled": True,
        "audio_out_enabled": True,
        "vad_analyzer": SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        "turn_analyzer": LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
    }


# Each value is a zero-argument factory, so nothing (e.g. SileroVADAnalyzer) is
# instantiated at import time; the factory is only invoked for the transport
# that actually gets selected.
transport_params = {
    "daily": lambda: DailyParams(**_shared_transport_kwargs()),
    "twilio": lambda: FastAPIWebsocketParams(**_shared_transport_kwargs()),
    "webrtc": lambda: TransportParams(**_shared_transport_kwargs()),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Run an STT -> LLM -> TTS voice pipeline that reports metrics via Sentry.

    Initializes the Sentry SDK from the SENTRY_DSN environment variable, gives
    each service its own SentryMetrics instance, and runs the pipeline task
    until it finishes or is cancelled.

    Args:
        transport: Transport whose input/output processors bracket the pipeline.
        runner_args: Runner arguments; supplies the pipeline idle timeout and
            the SIGINT-handling flag.
    """
    logger.info("Starting bot")

    # Initialize Sentry
    sentry_sdk.init(
        dsn=os.getenv("SENTRY_DSN"),
        traces_sample_rate=1.0,
    )

    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY"),
        metrics=SentryMetrics(),
    )

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
        metrics=SentryMetrics(),
    )

    llm = OpenAILLMService(
        api_key=os.getenv("OPENAI_API_KEY"),
        metrics=SentryMetrics(),
    )

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = LLMContext(messages)
    context_aggregator = LLMContextAggregatorPair(context)

    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,
            context_aggregator.user(),  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        # NOTE(review): this relies on LLMContext sharing the `messages` list
        # passed to it above rather than copying it; confirm.
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Entry point compatible with Pipecat Cloud.

    Builds a transport from the runner arguments and `transport_params`, then
    runs the bot on it.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -105,7 +105,7 @@ uv run 07-interruptible.py -t twilio -x NGROK_HOST_NAME
|
||||
### Vision & Multimodal
|
||||
|
||||
- **[12a-describe-video-gemini-flash.py](./12a-describe-video-gemini-flash.py)**: Bot describes user's video (Video input, Multimodal LLMs)
|
||||
- **[26c-gemini-multimodal-live-video.py](./26c-gemini-multimodal-live-video.py)**: Gemini with video input (Streaming video, Function calls)
|
||||
- **[26c-gemini-live-video.py](./26c-gemini-live-video.py)**: Gemini with video input (Streaming video, Function calls)
|
||||
|
||||
### Voice & Language
|
||||
|
||||
|
||||
BIN
examples/foundational/assets/moondream.png
Normal file
BIN
examples/foundational/assets/moondream.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
@@ -34,10 +34,11 @@ dependencies = [
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<=1.99.1",
|
||||
"openai>=1.74.0,<3",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
"pipecat-ai-cli"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -55,7 +56,7 @@ azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.19.9" ]
|
||||
daily = [ "daily-python~=0.20.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
@@ -84,7 +85,7 @@ nim = []
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
@@ -102,7 +103,7 @@ silero = [ "onnxruntime>=1.20.1,<2" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.4.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.5.0" ]
|
||||
strands = [ "strands-agents>=1.9.1,<2" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
|
||||
@@ -136,6 +136,7 @@ TESTS_14 = [
|
||||
("14r-function-calling-aws.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14v-function-calling-openai.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14w-function-calling-mistral.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14x-function-calling-openpipe.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
# Currently not working.
|
||||
# ("14c-function-calling-together.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
# ("14l-function-calling-deepseek.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
@@ -147,7 +148,10 @@ TESTS_15 = [
|
||||
]
|
||||
|
||||
TESTS_19 = [
|
||||
("19-openai-realtime.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19-openai-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
# OpenAI Realtime not released on Azure yet
|
||||
# ("19a-azure-realtime.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19a-azure-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-beta-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
@@ -160,18 +164,18 @@ TESTS_21 = [
|
||||
TESTS_26 = [
|
||||
("26-gemini-multimodal-live.py", PROMPT_SIMPLE_MATH, EVAL_SIMPLE_MATH, BOT_SPEAKS_FIRST),
|
||||
(
|
||||
"26a-gemini-multimodal-live-transcription.py",
|
||||
"26a-gemini-live-transcription.py",
|
||||
PROMPT_SIMPLE_MATH,
|
||||
EVAL_SIMPLE_MATH,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
(
|
||||
"26b-gemini-multimodal-live-function-calling.py",
|
||||
"26b-gemini-live-function-calling.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
("26c-gemini-multimodal-live-video.py", PROMPT_SIMPLE_MATH, EVAL_SIMPLE_MATH, BOT_SPEAKS_FIRST),
|
||||
("26c-gemini-live-video.py", PROMPT_SIMPLE_MATH, EVAL_SIMPLE_MATH, BOT_SPEAKS_FIRST),
|
||||
(
|
||||
"26e-gemini-multimodal-google-search.py",
|
||||
PROMPT_ONLINE_SEARCH,
|
||||
@@ -179,7 +183,13 @@ TESTS_26 = [
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
# Currently not working.
|
||||
# ("26d-gemini-multimodal-live-text.py", PROMPT_SIMPLE_MATH, EVAL_SIMPLE_MATH, BOT_SPEAKS_FIRST),
|
||||
# ("26d-gemini-live-text.py", PROMPT_SIMPLE_MATH, EVAL_SIMPLE_MATH, BOT_SPEAKS_FIRST),
|
||||
(
|
||||
"26h-gemini-live-vertex-function-calling.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
]
|
||||
|
||||
TESTS_27 = [
|
||||
|
||||
@@ -6,13 +6,47 @@
|
||||
|
||||
"""AWS Nova Sonic LLM adapter for Pipecat."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
|
||||
|
||||
|
||||
class Role(Enum):
    """Message roles recognised for AWS Nova Sonic conversations.

    Parameters:
        SYSTEM: System-level messages (not used in conversation history).
        USER: Messages sent by the user.
        ASSISTANT: Messages sent by the assistant.
        TOOL: Messages sent by tools (not used in conversation history).
    """

    # Each member's value is its own upper-case name.
    SYSTEM = "SYSTEM"
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicConversationHistoryMessage:
    """One entry of the conversation history passed to AWS Nova Sonic.

    Parameters:
        role: Who sent the message (USER or ASSISTANT only).
        text: The message's text content.
    """

    # Only Role.USER and Role.ASSISTANT are expected here.
    role: Role
    text: str
|
||||
|
||||
|
||||
class AWSNovaSonicLLMInvocationParams(TypedDict):
|
||||
@@ -21,7 +55,9 @@ class AWSNovaSonicLLMInvocationParams(TypedDict):
|
||||
This is a placeholder until support for universal LLMContext machinery is added for AWS Nova Sonic.
|
||||
"""
|
||||
|
||||
pass
|
||||
system_instruction: Optional[str]
|
||||
messages: List[AWSNovaSonicConversationHistoryMessage]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
@@ -34,7 +70,7 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
return "aws-nova-sonic"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -47,7 +83,13 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Nova Sonic's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools) or [],
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about AWS Nova Sonic.
|
||||
@@ -62,7 +104,75 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
Returns:
|
||||
List of messages in a format ready for logging about AWS Nova Sonic.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
return self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
@dataclass
class ConvertedMessages:
    """Container for AWS Nova Sonic-formatted messages converted from universal context.

    Parameters:
        messages: Conversation history messages (required; there is no default).
        system_instruction: Text of the leading system message, if one was found.
    """

    messages: List[AWSNovaSonicConversationHistoryMessage]
    system_instruction: Optional[str] = None
|
||||
|
||||
def _from_universal_context_messages(
    self, universal_context_messages: List[LLMContextMessage]
) -> ConvertedMessages:
    """Convert universal context messages into Nova Sonic history plus a system instruction.

    A leading "system" message is pulled out as the system instruction; every
    remaining message is converted with `_from_universal_context_message`, and
    messages that cannot be represented in Nova Sonic history are dropped.

    Args:
        universal_context_messages: Messages from the universal LLM context.

    Returns:
        ConvertedMessages holding the history messages and the system
        instruction (None when the context has no leading system message).
    """
    # Bail if there are no messages. `messages` has no default on
    # ConvertedMessages, so it must be passed explicitly.
    if not universal_context_messages:
        return self.ConvertedMessages(messages=[])

    # Work on a copy so that popping the system message below does not mutate
    # the caller's context.
    universal_context_messages = copy.deepcopy(universal_context_messages)

    # If we have a "system" message as our first message, let's pull that out into "instruction"
    system_instruction = None
    if universal_context_messages[0].get("role") == "system":
        system = universal_context_messages.pop(0)
        content = system.get("content")
        if isinstance(content, str):
            system_instruction = content
        elif isinstance(content, list) and content:
            # Only the first content part is used; an empty list yields no instruction.
            system_instruction = content[0].get("text")
        if system_instruction:
            # NOTE(review): the instruction is also stored on the adapter instance;
            # no reader of this attribute is visible here - confirm it is needed.
            self._system_instruction = system_instruction

    # Process remaining messages to fill out conversation history.
    # Nova Sonic supports "user" and "assistant" messages in history.
    converted = (
        self._from_universal_context_message(universal_context_message)
        for universal_context_message in universal_context_messages
    )
    messages = [message for message in converted if message]

    return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
|
||||
|
||||
def _from_universal_context_message(
    self, message
) -> Optional[AWSNovaSonicConversationHistoryMessage]:
    """Convert standard message format to Nova Sonic format.

    Args:
        message: Standard message dictionary to convert.

    Returns:
        Nova Sonic conversation history message, or None if not convertible
        (role other than "user"/"assistant", or no text content).
    """
    role = message.get("role")

    # NOTE: we're ignoring messages with role "tool" (or any other role) since they can't be
    # loaded into AWS Nova Sonic conversation history
    if role not in ("user", "assistant"):
        return None

    content = message.get("content")
    if isinstance(content, list):
        text_parts = []
        for part in content:
            if part.get("type") == "text":
                text = part.get("text")
                # Skip parts with missing/empty text instead of failing on them.
                if text:
                    text_parts.append(text)
            else:
                logger.error(
                    f"Unhandled content type in context message: {part.get('type')} - {message}"
                )
        # Join with single spaces; no stray leading space.
        content = " ".join(text_parts)

    # There won't be content if this is an assistant tool call entry.
    # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
    # history
    if not content:
        return None

    return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
|
||||
|
||||
@staticmethod
|
||||
def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -87,9 +87,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Includes both converted standard tools and any custom Gemini-specific tools.
|
||||
"""
|
||||
functions_schema = tools_schema.standard_tools
|
||||
formatted_standard_tools = [
|
||||
{"function_declarations": [func.to_default_dict() for func in functions_schema]}
|
||||
]
|
||||
formatted_standard_tools = (
|
||||
[{"function_declarations": [func.to_default_dict() for func in functions_schema]}]
|
||||
if functions_schema
|
||||
else []
|
||||
)
|
||||
custom_gemini_tools = []
|
||||
if tools_schema.custom_tools:
|
||||
custom_gemini_tools = tools_schema.custom_tools.get(AdapterType.GEMINI, [])
|
||||
|
||||
193
src/pipecat/audio/filters/krisp_viva_filter.py
Normal file
193
src/pipecat/audio/filters/krisp_viva_filter.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Krisp noise reduction audio filter for Pipecat.
|
||||
|
||||
This module provides an audio filter implementation using Krisp VIVA SDK.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
try:
|
||||
import krisp_audio
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the Krisp filter, you need to install krisp_audio.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def _log_callback(log_message, log_level):
    """Forward a Krisp SDK log message to loguru, prefixed with its SDK log level."""
    formatted = f"[{log_level}] {log_message}"
    logger.info(formatted)
|
||||
|
||||
|
||||
class KrispVivaFilter(BaseAudioFilter):
    """Audio filter using the Krisp VIVA SDK.

    Provides real-time noise reduction for audio streams using Krisp's
    proprietary noise suppression algorithms. This filter requires a
    valid Krisp model file to operate.

    Incoming audio is buffered and processed in whole 10 ms frames, so the
    number of bytes returned by ``filter()`` may differ from the number of
    bytes passed in; any partial frame is held until more audio arrives.

    Supported sample rates:
    - 8000 Hz
    - 16000 Hz
    - 24000 Hz
    - 32000 Hz
    - 44100 Hz
    - 48000 Hz
    """

    # Initialize Krisp Audio SDK globally
    krisp_audio.globalInit("", _log_callback, krisp_audio.LogLevel.Off)
    SDK_VERSION = krisp_audio.getVersion()
    logger.debug(
        f"Krisp Audio Python SDK Version: {SDK_VERSION.major}."
        f"{SDK_VERSION.minor}.{SDK_VERSION.patch}"
    )

    SAMPLE_RATES = {
        8000: krisp_audio.SamplingRate.Sr8000Hz,
        16000: krisp_audio.SamplingRate.Sr16000Hz,
        24000: krisp_audio.SamplingRate.Sr24000Hz,
        32000: krisp_audio.SamplingRate.Sr32000Hz,
        44100: krisp_audio.SamplingRate.Sr44100Hz,
        48000: krisp_audio.SamplingRate.Sr48000Hz,
    }

    FRAME_SIZE_MS = 10  # Krisp requires audio frames of 10ms duration for processing.

    def __init__(self, model_path: str = None, noise_suppression_level: int = 100) -> None:
        """Initialize the Krisp noise reduction filter.

        Args:
            model_path: Path to the Krisp model file (.kef extension).
                If None, uses KRISP_VIVA_MODEL_PATH environment variable.
            noise_suppression_level: Noise suppression level, passed unchanged
                to the Krisp session for every processed frame.

        Raises:
            ValueError: If model_path is not provided and KRISP_VIVA_MODEL_PATH is not set.
            Exception: If model file doesn't have .kef extension.
            FileNotFoundError: If model file doesn't exist.
        """
        super().__init__()

        # Set model path, checking environment if not specified
        self._model_path = model_path or os.getenv("KRISP_VIVA_MODEL_PATH")
        if not self._model_path:
            logger.error("Model path is not provided and KRISP_VIVA_MODEL_PATH is not set.")
            raise ValueError("Model path for KrispAudioProcessor must be provided.")

        if not self._model_path.endswith(".kef"):
            raise Exception("Model is expected with .kef extension")

        if not os.path.isfile(self._model_path):
            raise FileNotFoundError(f"Model file not found: {self._model_path}")

        self._filtering = True
        self._session = None
        self._samples_per_frame = None
        self._noise_suppression_level = noise_suppression_level

        # Audio buffer to accumulate samples for complete frames
        self._audio_buffer = bytearray()

    def _int_to_sample_rate(self, sample_rate):
        """Convert integer sample rate to krisp_audio SamplingRate enum.

        Args:
            sample_rate: Sample rate as integer

        Returns:
            krisp_audio.SamplingRate enum value

        Raises:
            ValueError: If sample rate is not supported
        """
        if sample_rate not in self.SAMPLE_RATES:
            raise ValueError("Unsupported sample rate")
        return self.SAMPLE_RATES[sample_rate]

    async def start(self, sample_rate: int):
        """Initialize the Krisp processor with the transport's sample rate.

        Args:
            sample_rate: The sample rate of the input transport in Hz.

        Raises:
            ValueError: If the sample rate is not one of SAMPLE_RATES.
        """
        model_info = krisp_audio.ModelInfo()
        model_info.path = self._model_path

        nc_cfg = krisp_audio.NcSessionConfig()
        nc_cfg.inputSampleRate = self._int_to_sample_rate(sample_rate)
        nc_cfg.inputFrameDuration = krisp_audio.FrameDuration.Fd10ms
        nc_cfg.outputSampleRate = nc_cfg.inputSampleRate
        nc_cfg.modelInfo = model_info

        # Drop leftovers from a previous session; they may have been captured
        # at a different sample rate.
        self._audio_buffer.clear()

        self._samples_per_frame = int((sample_rate * self.FRAME_SIZE_MS) / 1000)
        self._session = krisp_audio.NcInt16.create(nc_cfg)

    async def stop(self):
        """Clean up the Krisp processor when stopping."""
        self._session = None
        # Any partial frame still buffered can no longer be processed.
        self._audio_buffer.clear()

    async def process_frame(self, frame: FilterControlFrame):
        """Process control frames to enable/disable filtering.

        Args:
            frame: The control frame containing filter commands.
        """
        if isinstance(frame, FilterEnableFrame):
            self._filtering = frame.enable

    async def filter(self, audio: bytes) -> bytes:
        """Apply Krisp noise reduction to audio data.

        Audio is accumulated until at least one complete 10 ms frame is
        available; only complete frames are processed and returned.

        Args:
            audio: Raw audio data as bytes to be filtered (16-bit samples).

        Returns:
            Noise-reduced audio data as bytes. Empty if not enough audio has
            been buffered yet for a complete frame. The input is returned
            unprocessed when filtering is disabled or no session is active.
        """
        if not self._filtering:
            if self._audio_buffer:
                # Flush the partial frame held back while filtering was enabled so it
                # is neither lost nor prepended to much later audio on re-enable.
                pending = bytes(self._audio_buffer)
                self._audio_buffer.clear()
                return pending + audio
            return audio

        if self._session is None or not self._samples_per_frame:
            # start() has not been called yet (or stop() already was): pass through.
            return audio

        # Add incoming audio to our buffer
        self._audio_buffer.extend(audio)

        # Calculate how many complete frames we can process
        bytes_per_frame = self._samples_per_frame * 2  # 2 bytes per int16 sample
        num_complete_frames = len(self._audio_buffer) // bytes_per_frame

        if num_complete_frames == 0:
            # Not enough samples for a complete frame yet, return empty
            return b""

        # Extract the bytes we can process and keep the remainder buffered
        bytes_to_process = num_complete_frames * bytes_per_frame
        audio_to_process = bytes(self._audio_buffer[:bytes_to_process])
        del self._audio_buffer[:bytes_to_process]

        # Process the complete frames, one row per 10 ms frame
        samples = np.frombuffer(audio_to_process, dtype=np.int16)
        frames = samples.reshape(-1, self._samples_per_frame)
        processed_frames = np.empty_like(frames)

        for i, frame in enumerate(frames):
            processed_frames[i] = self._session.process(frame, self._noise_suppression_level)

        return processed_frames.tobytes()
|
||||
@@ -70,11 +70,15 @@ class PipelineRunner(BaseObject):
|
||||
"""
|
||||
logger.debug(f"Runner {self} started running {task}")
|
||||
self._tasks[task.name] = task
|
||||
params = PipelineTaskParams(loop=self._loop)
|
||||
|
||||
# PipelineTask handles asyncio.CancelledError to shutdown the pipeline
|
||||
# properly and re-raises it in case there's more cleanup to do.
|
||||
try:
|
||||
params = PipelineTaskParams(loop=self._loop)
|
||||
await task.run(params)
|
||||
except asyncio.CancelledError:
|
||||
await self._cancel()
|
||||
pass
|
||||
|
||||
del self._tasks[task.name]
|
||||
|
||||
# Cleanup base object.
|
||||
|
||||
@@ -138,6 +138,8 @@ class PipelineTask(BasePipelineTask):
|
||||
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
|
||||
the frame if they need to handle specific cases.
|
||||
|
||||
- on_pipeline_error: Called when an error occurs with ErrorFrame
|
||||
|
||||
Example::
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
@@ -148,9 +150,17 @@ class PipelineTask(BasePipelineTask):
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(task, frame):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_error")
|
||||
async def on_pipeline_error(task, frame):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -259,6 +269,9 @@ class PipelineTask(BasePipelineTask):
|
||||
# StopFrame) has been received at the end of the pipeline.
|
||||
self._pipeline_end_event = asyncio.Event()
|
||||
|
||||
# This event is set when the pipeline truly finishes.
|
||||
self._pipeline_finished_event = asyncio.Event()
|
||||
|
||||
# This is the final pipeline. It is composed of a source processor,
|
||||
# followed by the user pipeline, and ending with a sink processor. The
|
||||
# source allows us to receive and react to upstream frames, and the sink
|
||||
@@ -288,6 +301,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._register_event_handler("on_pipeline_ended")
|
||||
self._register_event_handler("on_pipeline_cancelled")
|
||||
self._register_event_handler("on_pipeline_finished")
|
||||
self._register_event_handler("on_pipeline_error")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
@@ -390,11 +404,7 @@ class PipelineTask(BasePipelineTask):
|
||||
await self.queue_frame(EndFrame())
|
||||
|
||||
async def cancel(self):
|
||||
"""Immediately stop the running pipeline.
|
||||
|
||||
Cancels all running tasks and stops frame processing without
|
||||
waiting for completion.
|
||||
"""
|
||||
"""Request the running pipeline to cancel."""
|
||||
if not self._finished:
|
||||
await self._cancel()
|
||||
|
||||
@@ -406,51 +416,38 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
if self.has_finished():
|
||||
return
|
||||
cleanup_pipeline = True
|
||||
|
||||
# Setup processors.
|
||||
await self._setup(params)
|
||||
|
||||
# Create all main tasks and wait for the main push task. This is the
|
||||
# task that pushes frames to the very beginning of our pipeline (i.e. to
|
||||
# our controlled source processor).
|
||||
await self._create_tasks()
|
||||
|
||||
try:
|
||||
# Setup processors.
|
||||
await self._setup(params)
|
||||
|
||||
# Create all main tasks and wait of the main push task. This is the
|
||||
# task that pushes frames to the very beginning of our pipeline (our
|
||||
# controlled source processor).
|
||||
push_task = await self._create_tasks()
|
||||
await push_task
|
||||
|
||||
# We have already cleaned up the pipeline inside the task.
|
||||
cleanup_pipeline = False
|
||||
|
||||
# Pipeline has finished nicely.
|
||||
self._finished = True
|
||||
# Wait for pipeline to finish.
|
||||
await self._wait_for_pipeline_finished()
|
||||
except asyncio.CancelledError:
|
||||
# Raise exception back to the pipeline runner so it can cancel this
|
||||
# task properly.
|
||||
logger.debug(f"Pipeline task {self} got cancelled from outside...")
|
||||
# We have been cancelled from outside, let's just cancel everything.
|
||||
await self._cancel()
|
||||
# Wait again for pipeline to finish. This time we have really
|
||||
# cancelled, so it should really finish.
|
||||
await self._wait_for_pipeline_finished()
|
||||
# Re-raise in case there's more cleanup to do.
|
||||
raise
|
||||
finally:
|
||||
# We can reach this point for different reasons:
|
||||
#
|
||||
# 1. The task has finished properly (e.g. `EndFrame`).
|
||||
# 2. By calling `PipelineTask.cancel()`.
|
||||
# 3. By asyncio task cancellation.
|
||||
#
|
||||
# Case (1) will execute the code below without issues because
|
||||
# `self._finished` is true.
|
||||
#
|
||||
# Case (2) will execute the code below without issues because
|
||||
# `self._cancelled` is true.
|
||||
#
|
||||
# Case (3) will raise the exception above (because we are cancelling
|
||||
# the asyncio task). This will be then captured by the
|
||||
# `PipelineRunner` which will call `PipelineTask.cancel()` and
|
||||
# therefore becoming case (2).
|
||||
if self._finished or self._cancelled:
|
||||
logger.debug(f"Pipeline task {self} is finishing cleanup...")
|
||||
await self._cancel_tasks()
|
||||
await self._cleanup(cleanup_pipeline)
|
||||
if self._check_dangling_tasks:
|
||||
self._print_dangling_tasks()
|
||||
self._finished = True
|
||||
logger.debug(f"Pipeline task {self} has finished")
|
||||
# 1. The pipeline task has finished (try case).
|
||||
# 2. By an asyncio task cancellation (except case).
|
||||
logger.debug(f"Pipeline task {self} is finishing...")
|
||||
await self._cancel_tasks()
|
||||
if self._check_dangling_tasks:
|
||||
self._print_dangling_tasks()
|
||||
self._finished = True
|
||||
logger.debug(f"Pipeline task {self} has finished")
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
"""Queue a single frame to be pushed down the pipeline.
|
||||
@@ -478,19 +475,7 @@ class PipelineTask(BasePipelineTask):
|
||||
if not self._cancelled:
|
||||
logger.debug(f"Cancelling pipeline task {self}")
|
||||
self._cancelled = True
|
||||
cancel_frame = CancelFrame()
|
||||
# Make sure everything is cleaned up downstream. This is sent
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._pipeline.queue_frame(cancel_frame)
|
||||
# Wait for CancelFrame to make it through the pipeline.
|
||||
await self._wait_for_pipeline_end(cancel_frame)
|
||||
# Only cancel the push task, we don't want to be able to process any
|
||||
# other frame after cancel. Everything else will be cancelled in
|
||||
# run().
|
||||
if self._process_push_task:
|
||||
await self._task_manager.cancel_task(self._process_push_task)
|
||||
self._process_push_task = None
|
||||
await self.queue_frame(CancelFrame())
|
||||
|
||||
async def _create_tasks(self):
|
||||
"""Create and start all pipeline processing tasks."""
|
||||
@@ -592,6 +577,17 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
self._pipeline_end_event.clear()
|
||||
|
||||
# We are really done.
|
||||
self._pipeline_finished_event.set()
|
||||
|
||||
async def _wait_for_pipeline_finished(self):
    """Wait for the pipeline-finished event, then for the push task to complete."""
    await self._pipeline_finished_event.wait()

    # Reset the event so a subsequent wait blocks until it is set again.
    self._pipeline_finished_event.clear()

    # Make sure we wait for the main task to complete.
    if self._process_push_task:
        await self._process_push_task
        self._process_push_task = None
|
||||
|
||||
async def _setup(self, params: PipelineTaskParams):
|
||||
"""Set up the pipeline task and all processors."""
|
||||
mgr_params = TaskManagerParams(loop=params.loop)
|
||||
@@ -694,12 +690,11 @@ class PipelineTask(BasePipelineTask):
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
await self._call_event_handler("on_pipeline_error", frame)
|
||||
if frame.fatal:
|
||||
logger.error(f"A fatal error occurred: {frame}")
|
||||
# Cancel all tasks downstream.
|
||||
await self.queue_frame(CancelFrame())
|
||||
# Tell the task we should stop.
|
||||
await self.queue_frame(StopTaskFrame())
|
||||
else:
|
||||
logger.warning(f"{self}: Something went wrong: {frame}")
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@ service-specific adapter.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, TypeAlias, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -31,6 +32,9 @@ from PIL import Image
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
|
||||
# "Re-export" types from OpenAI that we're using as universal context types.
|
||||
# NOTE: if universal message types need to someday diverge from OpenAI's, we
|
||||
# should consider managing our own definitions. But we should do so carefully,
|
||||
@@ -65,6 +69,26 @@ class LLMContext:
|
||||
and content formatting.
|
||||
"""
|
||||
|
||||
@staticmethod
def from_openai_context(openai_context: "OpenAILLMContext") -> "LLMContext":
    """Build a universal LLMContext out of an OpenAI-specific context.

    NOTE: internal helper intended only to ease the migration from
    OpenAILLMContext to LLMContext. New user code should construct an
    LLMContext directly.

    Args:
        openai_context: The OpenAI LLM context whose messages, tools and
            tool choice are carried over.

    Returns:
        A new LLMContext populated from the given OpenAI context.
    """
    settings = {
        "messages": openai_context.get_messages(),
        "tools": openai_context.tools,
        "tool_choice": openai_context.tool_choice,
    }
    return LLMContext(**settings)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[LLMContextMessage]] = None,
|
||||
|
||||
@@ -877,6 +877,8 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
"""
|
||||
while True:
|
||||
(frame, direction, callback) = await self.__input_queue.get()
|
||||
|
||||
if self.__should_block_system_frames and self.__input_event:
|
||||
logger.trace(f"{self}: system frame processing paused")
|
||||
await self.__input_event.wait()
|
||||
@@ -884,8 +886,6 @@ class FrameProcessor(BaseObject):
|
||||
self.__should_block_system_frames = False
|
||||
logger.trace(f"{self}: system frame processing resumed")
|
||||
|
||||
(frame, direction, callback) = await self.__input_queue.get()
|
||||
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
elif self.__process_queue:
|
||||
@@ -900,6 +900,8 @@ class FrameProcessor(BaseObject):
|
||||
async def __process_frame_task_handler(self):
|
||||
"""Handle non-system frames from the process queue."""
|
||||
while True:
|
||||
(frame, direction, callback) = await self.__process_queue.get()
|
||||
|
||||
if self.__should_block_frames and self.__process_event:
|
||||
logger.trace(f"{self}: frame processing paused")
|
||||
await self.__process_event.wait()
|
||||
@@ -907,8 +909,6 @@ class FrameProcessor(BaseObject):
|
||||
self.__should_block_frames = False
|
||||
logger.trace(f"{self}: frame processing resumed")
|
||||
|
||||
(frame, direction, callback) = await self.__process_queue.get()
|
||||
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
|
||||
self.__process_queue.task_done()
|
||||
|
||||
@@ -82,6 +82,7 @@ async def configure(
|
||||
sip_enable_video: Optional[bool] = False,
|
||||
sip_num_endpoints: Optional[int] = 1,
|
||||
sip_codecs: Optional[Dict[str, List[str]]] = None,
|
||||
room_properties: Optional[DailyRoomProperties] = None,
|
||||
) -> DailyRoomConfig:
|
||||
"""Configure Daily room URL and token with optional SIP capabilities.
|
||||
|
||||
@@ -99,6 +100,10 @@ async def configure(
|
||||
sip_num_endpoints: Number of allowed SIP endpoints.
|
||||
sip_codecs: Codecs to support for audio and video. If None, uses Daily defaults.
|
||||
Example: {"audio": ["OPUS"], "video": ["H264"]}
|
||||
room_properties: Optional DailyRoomProperties to use instead of building from
|
||||
individual parameters. When provided, this overrides room_exp_duration and
|
||||
SIP-related parameters. If not provided, properties are built from the
|
||||
individual parameters as before.
|
||||
|
||||
Returns:
|
||||
DailyRoomConfig: Object with room_url, token, and optional sip_endpoint.
|
||||
@@ -115,6 +120,13 @@ async def configure(
|
||||
# SIP-enabled room
|
||||
sip_config = await configure(session, sip_caller_phone="+15551234567")
|
||||
print(f"SIP endpoint: {sip_config.sip_endpoint}")
|
||||
|
||||
# Custom room properties with recording enabled
|
||||
custom_props = DailyRoomProperties(
|
||||
enable_recording="cloud",
|
||||
max_participants=2,
|
||||
)
|
||||
config = await configure(session, room_properties=custom_props)
|
||||
"""
|
||||
# Check for required API key
|
||||
api_key = os.getenv("DAILY_API_KEY")
|
||||
@@ -124,9 +136,32 @@ async def configure(
|
||||
"Get your API key from https://dashboard.daily.co/developers"
|
||||
)
|
||||
|
||||
# Warn if both room_properties and individual parameters are provided
|
||||
if room_properties is not None:
|
||||
individual_params_provided = any(
|
||||
[
|
||||
room_exp_duration != 2.0,
|
||||
token_exp_duration != 2.0,
|
||||
sip_caller_phone is not None,
|
||||
sip_enable_video is not False,
|
||||
sip_num_endpoints != 1,
|
||||
sip_codecs is not None,
|
||||
]
|
||||
)
|
||||
if individual_params_provided:
|
||||
logger.warning(
|
||||
"Both room_properties and individual parameters (room_exp_duration, token_exp_duration, "
|
||||
"sip_*) were provided. The room_properties will be used and individual parameters "
|
||||
"will be ignored."
|
||||
)
|
||||
|
||||
# Determine if SIP mode is enabled
|
||||
sip_enabled = sip_caller_phone is not None
|
||||
|
||||
# If room_properties is provided, check if it has SIP configuration
|
||||
if room_properties and room_properties.sip:
|
||||
sip_enabled = True
|
||||
|
||||
daily_rest_helper = DailyRESTHelper(
|
||||
daily_api_key=api_key,
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
@@ -150,27 +185,29 @@ async def configure(
|
||||
room_name = f"{room_prefix}-{uuid.uuid4().hex[:8]}"
|
||||
logger.info(f"Creating new Daily room: {room_name}")
|
||||
|
||||
# Calculate expiration time
|
||||
expiration_time = time.time() + (room_exp_duration * 60 * 60)
|
||||
# Use provided room_properties or build from parameters
|
||||
if room_properties is None:
|
||||
# Calculate expiration time
|
||||
expiration_time = time.time() + (room_exp_duration * 60 * 60)
|
||||
|
||||
# Create room properties
|
||||
room_properties = DailyRoomProperties(
|
||||
exp=expiration_time,
|
||||
eject_at_room_exp=True,
|
||||
)
|
||||
|
||||
# Add SIP configuration if enabled
|
||||
if sip_enabled:
|
||||
sip_params = DailyRoomSipParams(
|
||||
display_name=sip_caller_phone,
|
||||
video=sip_enable_video,
|
||||
sip_mode="dial-in",
|
||||
num_endpoints=sip_num_endpoints,
|
||||
codecs=sip_codecs,
|
||||
# Create room properties
|
||||
room_properties = DailyRoomProperties(
|
||||
exp=expiration_time,
|
||||
eject_at_room_exp=True,
|
||||
)
|
||||
room_properties.sip = sip_params
|
||||
room_properties.enable_dialout = True # Enable outbound calls if needed
|
||||
room_properties.start_video_off = not sip_enable_video # Voice-only by default
|
||||
|
||||
# Add SIP configuration if enabled
|
||||
if sip_enabled:
|
||||
sip_params = DailyRoomSipParams(
|
||||
display_name=sip_caller_phone,
|
||||
video=sip_enable_video,
|
||||
sip_mode="dial-in",
|
||||
num_endpoints=sip_num_endpoints,
|
||||
codecs=sip_codecs,
|
||||
)
|
||||
room_properties.sip = sip_params
|
||||
room_properties.enable_dialout = True # Enable outbound calls if needed
|
||||
room_properties.start_video_off = not sip_enable_video # Voice-only by default
|
||||
|
||||
# Create room parameters
|
||||
room_params = DailyRoomParams(name=room_name, properties=room_properties)
|
||||
|
||||
@@ -67,12 +67,17 @@ To run locally:
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
from http import HTTPMethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
import aiohttp
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.runner.types import (
|
||||
@@ -98,6 +103,12 @@ except ImportError as e:
|
||||
load_dotenv(override=True)
|
||||
os.environ["ENV"] = "local"
|
||||
|
||||
TELEPHONY_TRANSPORTS = ["twilio", "telnyx", "plivo", "exotel"]
|
||||
|
||||
RUNNER_DOWNLOADS_FOLDER: Optional[str] = None
|
||||
RUNNER_HOST: str = "localhost"
|
||||
RUNNER_PORT: int = 7860
|
||||
|
||||
|
||||
def _get_bot_module():
|
||||
"""Get the bot module from the calling script."""
|
||||
@@ -152,7 +163,13 @@ async def _run_telephony_bot(websocket: WebSocket):
|
||||
|
||||
|
||||
def _create_server_app(
|
||||
transport_type: str, host: str = "localhost", proxy: str = None, esp32_mode: bool = False
|
||||
*,
|
||||
transport_type: str,
|
||||
host: str = "localhost",
|
||||
proxy: str,
|
||||
esp32_mode: bool = False,
|
||||
whatsapp_enabled: bool = False,
|
||||
folder: Optional[str] = None,
|
||||
):
|
||||
"""Create FastAPI app with transport-specific routes."""
|
||||
app = FastAPI()
|
||||
@@ -167,25 +184,30 @@ def _create_server_app(
|
||||
|
||||
# Set up transport-specific routes
|
||||
if transport_type == "webrtc":
|
||||
_setup_webrtc_routes(app, esp32_mode=esp32_mode, host=host)
|
||||
_setup_whatsapp_routes(app)
|
||||
_setup_webrtc_routes(app, esp32_mode=esp32_mode, host=host, folder=folder)
|
||||
if whatsapp_enabled:
|
||||
_setup_whatsapp_routes(app)
|
||||
elif transport_type == "daily":
|
||||
_setup_daily_routes(app)
|
||||
elif transport_type in ["twilio", "telnyx", "plivo", "exotel"]:
|
||||
_setup_telephony_routes(app, transport_type, proxy)
|
||||
elif transport_type in TELEPHONY_TRANSPORTS:
|
||||
_setup_telephony_routes(app, transport_type=transport_type, proxy=proxy)
|
||||
else:
|
||||
logger.warning(f"Unknown transport type: {transport_type}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "localhost"):
|
||||
def _setup_webrtc_routes(
|
||||
app: FastAPI, *, esp32_mode: bool = False, host: str = "localhost", folder: Optional[str] = None
|
||||
):
|
||||
"""Set up WebRTC-specific routes."""
|
||||
try:
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.request_handler import (
|
||||
IceCandidate,
|
||||
SmallWebRTCPatchRequest,
|
||||
SmallWebRTCRequest,
|
||||
SmallWebRTCRequestHandler,
|
||||
)
|
||||
@@ -193,6 +215,16 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
logger.error(f"WebRTC transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
class IceConfig(TypedDict):
    """ICE configuration returned to the client by the /start endpoint."""

    # ICE servers the client should use.
    iceServers: List[IceServer]
|
||||
|
||||
class StartBotResult(TypedDict, total=False):
    """Response body of the /start endpoint (all keys optional)."""

    # Identifier generated for the newly started session.
    sessionId: str
    # Only set when the request asks for default ICE servers.
    iceConfig: Optional[IceConfig]
|
||||
|
||||
# In-memory store of active sessions: session_id -> session info
|
||||
active_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Mount the frontend
|
||||
app.mount("/client", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -201,6 +233,21 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
"""Redirect root requests to client interface."""
|
||||
return RedirectResponse(url="/client/")
|
||||
|
||||
@app.get("/files/{filename:path}")
async def download_file(filename: str):
    """Handle file downloads from the configured downloads folder.

    Args:
        filename: Path of the requested file, relative to the downloads
            folder. Comes straight from the request URL.

    Returns:
        FileResponse streaming the requested file.

    Raises:
        HTTPException: 404 if no downloads folder is configured, if the
            path points outside the downloads folder, or if it is not an
            existing regular file.
    """
    if not folder:
        logger.warning(f"Attempting to download {filename}, but downloads folder not setup.")
        raise HTTPException(404)

    base_dir = Path(folder).resolve()
    file_path = (base_dir / filename).resolve()

    # `filename` is untrusted: ".." segments or an absolute path would otherwise
    # let a client read arbitrary files. Only serve regular files that resolve
    # to a location inside the downloads folder.
    if not file_path.is_relative_to(base_dir) or not file_path.is_file():
        raise HTTPException(404)

    media_type, _ = mimetypes.guess_type(file_path)

    return FileResponse(path=file_path, media_type=media_type, filename=file_path.name)
|
||||
|
||||
# Initialize the SmallWebRTC request handler
|
||||
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
|
||||
esp32_mode=esp32_mode, host=host
|
||||
@@ -223,6 +270,74 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
)
|
||||
return answer
|
||||
|
||||
@app.patch("/api/offer")
async def ice_candidate(request: SmallWebRTCPatchRequest):
    """Accept new ICE candidates for an existing WebRTC connection.

    The patch request is handed to the SmallWebRTC request handler and a
    fixed success payload is returned.
    """
    logger.debug(f"Received patch request: {request}")
    await small_webrtc_handler.handle_patch_request(request)
    response = {"status": "success"}
    return response
|
||||
|
||||
@app.post("/start")
async def rtvi_start(request: Request):
    """Mimic Pipecat Cloud's /start endpoint.

    Creates a new in-memory session keyed by a random UUID and stores the
    request body with it. A missing, unparsable, or non-object JSON body is
    treated as an empty body.

    Args:
        request: Incoming request; its JSON body (if any) becomes the
            session info.

    Returns:
        StartBotResult with the new ``sessionId`` and, when the body sets
        ``enableDefaultIceServers``, an ``iceConfig`` with a default STUN
        server.
    """
    # Parse the request body
    try:
        request_data = await request.json()
        logger.debug(f"Received request: {request_data}")
    except Exception as e:
        logger.error(f"Failed to parse request body: {e}")
        request_data = {}

    # Valid JSON is not necessarily an object (e.g. a list or a string);
    # the code below needs a dict.
    if not isinstance(request_data, dict):
        logger.error(f"Unexpected request body type: {type(request_data).__name__}")
        request_data = {}

    # Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
    session_id = str(uuid.uuid4())
    active_sessions[session_id] = request_data

    result: StartBotResult = {"sessionId": session_id}
    if request_data.get("enableDefaultIceServers"):
        result["iceConfig"] = IceConfig(
            iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]
        )

    return result
|
||||
|
||||
@app.api_route(
    "/sessions/{session_id}/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)
async def proxy_request(
    session_id: str, path: str, request: Request, background_tasks: BackgroundTasks
):
    """Mimic Pipecat Cloud's proxy.

    Requests whose path ends in ``api/offer`` are converted to the matching
    SmallWebRTC request type and forwarded to the local ``offer`` (POST) or
    ``ice_candidate`` (PATCH) handler. Anything else is acknowledged with an
    empty 200 response.

    Args:
        session_id: Session identifier previously returned by /start.
        path: Remainder of the URL after the session id.
        request: Incoming request.
        background_tasks: Background tasks handle, forwarded to ``offer``.

    Returns:
        The forwarded handler's result, or a plain Response (404 for an
        unknown session, 400 for a malformed WebRTC request, 200 otherwise).
    """
    # Check membership rather than truthiness: a session started with an empty
    # body is stored as `{}` and must still count as a valid session.
    if session_id not in active_sessions:
        return Response(content="Invalid or not-yet-ready session_id", status_code=404)

    if path.endswith("api/offer"):
        webrtc_request = None
        patch_request = None

        # Parse the request body and convert to SmallWebRTCRequest. Only the
        # parsing is guarded, so failures inside the handlers themselves are
        # not misreported as an invalid request.
        try:
            request_data = await request.json()
            if request.method == HTTPMethod.POST.value:
                webrtc_request = SmallWebRTCRequest(
                    sdp=request_data["sdp"],
                    type=request_data["type"],
                    pc_id=request_data.get("pc_id"),
                    restart_pc=request_data.get("restart_pc"),
                    request_data=request_data,
                )
            elif request.method == HTTPMethod.PATCH.value:
                patch_request = SmallWebRTCPatchRequest(
                    pc_id=request_data["pc_id"],
                    candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])],
                )
        except Exception as e:
            logger.error(f"Failed to parse WebRTC request: {e}")
            return Response(content="Invalid WebRTC request", status_code=400)

        if webrtc_request is not None:
            return await offer(webrtc_request, background_tasks)
        if patch_request is not None:
            return await ice_candidate(patch_request)

    logger.info(f"Received request for path: {path}")
    return Response(status_code=200)
|
||||
|
||||
@asynccontextmanager
|
||||
async def smallwebrtc_lifespan(app: FastAPI):
|
||||
"""Manage FastAPI application lifecycle and cleanup connections."""
|
||||
@@ -258,6 +373,29 @@ def _add_lifespan_to_app(app: FastAPI, new_lifespan):
|
||||
|
||||
def _setup_whatsapp_routes(app: FastAPI):
|
||||
"""Set up WebRTC-specific routes."""
|
||||
WHATSAPP_APP_SECRET = os.getenv("WHATSAPP_APP_SECRET")
|
||||
WHATSAPP_PHONE_NUMBER_ID = os.getenv("WHATSAPP_PHONE_NUMBER_ID")
|
||||
WHATSAPP_TOKEN = os.getenv("WHATSAPP_TOKEN")
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN = os.getenv("WHATSAPP_WEBHOOK_VERIFICATION_TOKEN")
|
||||
|
||||
if not all(
|
||||
[
|
||||
WHATSAPP_APP_SECRET,
|
||||
WHATSAPP_PHONE_NUMBER_ID,
|
||||
WHATSAPP_TOKEN,
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN,
|
||||
]
|
||||
):
|
||||
logger.error(
|
||||
"""Missing required environment variables for WhatsApp transport:
|
||||
WHATSAPP_APP_SECRET
|
||||
WHATSAPP_PHONE_NUMBER_ID
|
||||
WHATSAPP_TOKEN
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
@@ -269,24 +407,7 @@ def _setup_whatsapp_routes(app: FastAPI):
|
||||
from pipecat.transports.whatsapp.api import WhatsAppWebhookRequest
|
||||
from pipecat.transports.whatsapp.client import WhatsAppClient
|
||||
except ImportError as e:
|
||||
logger.error(f"WebRTC transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
WHATSAPP_TOKEN = os.getenv("WHATSAPP_TOKEN")
|
||||
WHATSAPP_PHONE_NUMBER_ID = os.getenv("WHATSAPP_PHONE_NUMBER_ID")
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN = os.getenv("WHATSAPP_WEBHOOK_VERIFICATION_TOKEN")
|
||||
WHATSAPP_APP_SECRET = os.getenv("WHATSAPP_APP_SECRET")
|
||||
|
||||
if not all(
|
||||
[
|
||||
WHATSAPP_TOKEN,
|
||||
WHATSAPP_PHONE_NUMBER_ID,
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN,
|
||||
]
|
||||
):
|
||||
logger.debug(
|
||||
"Missing required environment variables for WhatsApp transport. Keeping it disabled."
|
||||
)
|
||||
logger.error(f"WhatsApp transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
# Global WhatsApp client instance
|
||||
@@ -456,8 +577,6 @@ def _setup_daily_routes(app: FastAPI):
|
||||
else:
|
||||
logger.debug("No body data provided in request")
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pipecat.runner.daily import configure
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -489,7 +608,7 @@ def _setup_daily_routes(app: FastAPI):
|
||||
return await _handle_rtvi_request(request)
|
||||
|
||||
|
||||
def _setup_telephony_routes(app: FastAPI, transport_type: str, proxy: str):
|
||||
def _setup_telephony_routes(app: FastAPI, *, transport_type: str, proxy: str):
|
||||
"""Set up telephony-specific routes."""
|
||||
# XML response templates (Exotel doesn't use XML webhooks)
|
||||
XML_TEMPLATES = {
|
||||
@@ -545,8 +664,6 @@ def _setup_telephony_routes(app: FastAPI, transport_type: str, proxy: str):
|
||||
async def _run_daily_direct():
|
||||
"""Run Daily bot with direct connection (no FastAPI server)."""
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
from pipecat.runner.daily import configure
|
||||
except ImportError as e:
|
||||
logger.error("Daily transport dependencies not installed.")
|
||||
@@ -592,6 +709,21 @@ def _validate_and_clean_proxy(proxy: str) -> str:
|
||||
return proxy
|
||||
|
||||
|
||||
def runner_downloads_folder() -> Optional[str]:
|
||||
"""Returns the folder where files are stored for later download."""
|
||||
return RUNNER_DOWNLOADS_FOLDER
|
||||
|
||||
|
||||
def runner_host() -> str:
|
||||
"""Returns the host name of this runner."""
|
||||
return RUNNER_HOST
|
||||
|
||||
|
||||
def runner_port() -> int:
|
||||
"""Returns the port of this runner."""
|
||||
return RUNNER_PORT
|
||||
|
||||
|
||||
def main():
|
||||
"""Start the Pipecat development runner.
|
||||
|
||||
@@ -612,14 +744,16 @@ def main():
|
||||
|
||||
The bot file must contain a `bot(runner_args)` function as the entry point.
|
||||
"""
|
||||
global RUNNER_DOWNLOADS_FOLDER, RUNNER_HOST, RUNNER_PORT
|
||||
|
||||
parser = argparse.ArgumentParser(description="Pipecat Development Runner")
|
||||
parser.add_argument("--host", type=str, default="localhost", help="Host address")
|
||||
parser.add_argument("--port", type=int, default=7860, help="Port number")
|
||||
parser.add_argument("--host", type=str, default=RUNNER_HOST, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=RUNNER_PORT, help="Port number")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--transport",
|
||||
type=str,
|
||||
choices=["daily", "webrtc", "twilio", "telnyx", "plivo", "exotel"],
|
||||
choices=["daily", "webrtc", *TELEPHONY_TRANSPORTS],
|
||||
default="webrtc",
|
||||
help="Transport type",
|
||||
)
|
||||
@@ -637,9 +771,16 @@ def main():
|
||||
default=False,
|
||||
help="Connect directly to Daily room (automatically sets transport to daily)",
|
||||
)
|
||||
parser.add_argument("-f", "--folder", type=str, help="Path to downloads folder")
|
||||
parser.add_argument(
|
||||
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whatsapp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Ensure requried WhatsApp environment variables are present",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -659,6 +800,10 @@ def main():
|
||||
logger.error("For ESP32, you need to specify `--host IP` so we can do SDP munging.")
|
||||
return
|
||||
|
||||
if args.transport in TELEPHONY_TRANSPORTS and not args.proxy:
|
||||
logger.error(f"For telephony transports, you need to specify `--proxy PROXY`.")
|
||||
return
|
||||
|
||||
# Log level
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level="TRACE" if args.verbose else "DEBUG")
|
||||
@@ -678,10 +823,11 @@ def main():
|
||||
print()
|
||||
if args.esp32:
|
||||
print(f"🚀 Bot ready! (ESP32 mode)")
|
||||
print(f" → Open http://{args.host}:{args.port}/client in your browser")
|
||||
elif args.whatsapp:
|
||||
print(f"🚀 Bot ready! (WhatsApp)")
|
||||
else:
|
||||
print(f"🚀 Bot ready!")
|
||||
print(f" → Open http://{args.host}:{args.port}/client in your browser")
|
||||
print(f" → Open http://{args.host}:{args.port}/client in your browser")
|
||||
print()
|
||||
elif args.transport == "daily":
|
||||
print()
|
||||
@@ -689,8 +835,19 @@ def main():
|
||||
print(f" → Open http://{args.host}:{args.port} in your browser to start a session")
|
||||
print()
|
||||
|
||||
RUNNER_DOWNLOADS_FOLDER = args.folder
|
||||
RUNNER_HOST = args.host
|
||||
RUNNER_PORT = args.port
|
||||
|
||||
# Create the app with transport-specific setup
|
||||
app = _create_server_app(args.transport, args.host, args.proxy, args.esp32)
|
||||
app = _create_server_app(
|
||||
transport_type=args.transport,
|
||||
host=args.host,
|
||||
proxy=args.proxy,
|
||||
esp32_mode=args.esp32,
|
||||
whatsapp_enabled=args.whatsapp,
|
||||
folder=args.folder,
|
||||
)
|
||||
|
||||
# Run the server
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -25,11 +25,31 @@ except ModuleNotFoundError as e:
|
||||
class LivekitFrameSerializer(FrameSerializer):
|
||||
"""Serializer for converting between Pipecat frames and LiveKit audio frames.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
|
||||
This class is deprecated and will be removed in a future version.
|
||||
Please use LiveKitTransport instead, which handles audio streaming
|
||||
and frame conversion natively.
|
||||
|
||||
This serializer handles the conversion of Pipecat's OutputAudioRawFrame objects
|
||||
to LiveKit AudioFrame objects for transmission, and the reverse conversion
|
||||
for received audio data.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the LiveKit frame serializer."""
|
||||
super().__init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"LivekitFrameSerializer is deprecated and will be removed in a future version. "
|
||||
"Please use LiveKitTransport instead, which handles audio streaming natively.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Get the serializer type.
|
||||
|
||||
@@ -97,9 +97,7 @@ class AIService(FrameProcessor):
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
from pipecat.services.openai_realtime_beta.events import (
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai.realtime.events import SessionProperties
|
||||
|
||||
for key, value in settings.items():
|
||||
logger.debug("Update request for:", key, value)
|
||||
@@ -111,9 +109,7 @@ class AIService(FrameProcessor):
|
||||
logger.debug("Attempting to update", key, value)
|
||||
|
||||
try:
|
||||
from pipecat.services.openai_realtime_beta.events import (
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.services.openai.realtime.events import TurnDetection
|
||||
|
||||
if isinstance(self._session_properties, SessionProperties):
|
||||
current_properties = self._session_properties
|
||||
|
||||
@@ -108,6 +108,8 @@ class AssemblyAIConnectionParams(BaseModel):
|
||||
end_of_turn_confidence_threshold: Confidence threshold for end-of-turn detection.
|
||||
min_end_of_turn_silence_when_confident: Minimum silence duration when confident about end-of-turn.
|
||||
max_turn_silence: Maximum silence duration before forcing end-of-turn.
|
||||
keyterms_prompt: List of key terms to guide transcription. Will be JSON serialized before sending.
|
||||
speech_model: Select between English and multilingual models. Defaults to "universal-streaming-english".
|
||||
"""
|
||||
|
||||
sample_rate: int = 16000
|
||||
@@ -117,3 +119,7 @@ class AssemblyAIConnectionParams(BaseModel):
|
||||
end_of_turn_confidence_threshold: Optional[float] = None
|
||||
min_end_of_turn_silence_when_confident: Optional[int] = None
|
||||
max_turn_silence: Optional[int] = None
|
||||
keyterms_prompt: Optional[List[str]] = None
|
||||
speech_model: Literal["universal-streaming-english", "universal-streaming-multilingual"] = (
|
||||
"universal-streaming-english"
|
||||
)
|
||||
|
||||
@@ -174,11 +174,16 @@ class AssemblyAISTTService(STTService):
|
||||
|
||||
def _build_ws_url(self) -> str:
|
||||
"""Build WebSocket URL with query parameters using urllib.parse.urlencode."""
|
||||
params = {
|
||||
k: str(v).lower() if isinstance(v, bool) else v
|
||||
for k, v in self._connection_params.model_dump().items()
|
||||
if v is not None
|
||||
}
|
||||
params = {}
|
||||
for k, v in self._connection_params.model_dump().items():
|
||||
if v is not None:
|
||||
if k == "keyterms_prompt":
|
||||
params[k] = json.dumps(v)
|
||||
elif isinstance(v, bool):
|
||||
params[k] = str(v).lower()
|
||||
else:
|
||||
params[k] = v
|
||||
|
||||
if params:
|
||||
query_string = urlencode(params)
|
||||
return f"{self._api_endpoint_base_url}?{query_string}"
|
||||
@@ -197,6 +202,8 @@ class AssemblyAISTTService(STTService):
|
||||
)
|
||||
self._connected = True
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AssemblyAI: {e}")
|
||||
self._connected = False
|
||||
@@ -238,6 +245,7 @@ class AssemblyAISTTService(STTService):
|
||||
self._websocket = None
|
||||
self._connected = False
|
||||
self._receive_task = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
|
||||
@@ -235,6 +235,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
}
|
||||
|
||||
await self._get_websocket().send(json.dumps(init_msg))
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -252,6 +254,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .nova_sonic import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
0
src/pipecat/services/aws/nova_sonic/__init__.py
Normal file
0
src/pipecat/services/aws/nova_sonic/__init__.py
Normal file
87
src/pipecat/services/aws/nova_sonic/context.py
Normal file
87
src/pipecat/services/aws/nova_sonic/context.py
Normal file
@@ -0,0 +1,87 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Context management for AWS Nova Sonic LLM service.
|
||||
|
||||
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
|
||||
including conversation history management and role-specific message processing.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
AWS Nova Sonic now supports `LLMContext` and `LLMContextAggregatorPair`.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
BEFORE:
|
||||
```
|
||||
# Setup
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Context frame type
|
||||
frame: OpenAILLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: AWSNovaSonicLLMContext
|
||||
# or
|
||||
context: OpenAILLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.messages
|
||||
```
|
||||
|
||||
AFTER:
|
||||
```
|
||||
# Setup
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Context frame type
|
||||
frame: LLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: LLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.get_messages()
|
||||
```
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws.nova_sonic.context are deprecated. \n"
|
||||
"AWS Nova Sonic now supports `LLMContext` and `LLMContextAggregatorPair`. \n"
|
||||
"Using the new patterns should allow you to not need types from this module.\n\n"
|
||||
"BEFORE:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = OpenAILLMContext(messages, tools)\n"
|
||||
"context_aggregator = llm.create_context_aggregator(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: OpenAILLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: AWSNovaSonicLLMContext\n"
|
||||
"# or\n"
|
||||
"context: OpenAILLMContext\n\n"
|
||||
"# Reading messages from context\n"
|
||||
"messages = context.messages\n"
|
||||
"```\n\n"
|
||||
"AFTER:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = LLMContext(messages, tools)\n"
|
||||
"context_aggregator = LLMContextAggregatorPair(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: LLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: LLMContext\n\n"
|
||||
"# Reading messages from context\n"
|
||||
"messages = context.messages\n"
|
||||
"```",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
25
src/pipecat/services/aws/nova_sonic/frames.py
Normal file
25
src/pipecat/services/aws/nova_sonic/frames.py
Normal file
@@ -0,0 +1,25 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frames for AWS Nova Sonic LLM service."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call result for AWS Nova Sonic processing.
|
||||
|
||||
This frame wraps a standard function call result frame to enable
|
||||
AWS Nova Sonic-specific handling and context updates.
|
||||
|
||||
Parameters:
|
||||
result_frame: The underlying function call result frame.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
1265
src/pipecat/services/aws/nova_sonic/llm.py
Normal file
1265
src/pipecat/services/aws/nova_sonic/llm.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -286,6 +286,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
await self._disconnect()
|
||||
@@ -310,6 +311,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert internal language enum to AWS Transcribe language code.
|
||||
|
||||
@@ -1 +1,19 @@
|
||||
from .aws import AWSNovaSonicLLMService, Params
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService, Params
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws_nova_sonic are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.aws.nova_sonic.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,360 +8,80 @@
|
||||
|
||||
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
|
||||
including conversation history management and role-specific message processing.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
AWS Nova Sonic now supports `LLMContext` and `LLMContextAggregatorPair`.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
BEFORE:
|
||||
```
|
||||
# Setup
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Context frame type
|
||||
frame: OpenAILLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: AWSNovaSonicLLMContext
|
||||
# or
|
||||
context: OpenAILLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.messages
|
||||
```
|
||||
|
||||
AFTER:
|
||||
```
|
||||
# Setup
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Context frame type
|
||||
frame: LLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: LLMContext
|
||||
|
||||
# Reading messages from context
|
||||
messages = context.get_messages()
|
||||
```
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
"""Roles supported in AWS Nova Sonic conversations.
|
||||
|
||||
Parameters:
|
||||
SYSTEM: System-level messages (not used in conversation history).
|
||||
USER: Messages sent by the user.
|
||||
ASSISTANT: Messages sent by the assistant.
|
||||
TOOL: Messages sent by tools (not used in conversation history).
|
||||
"""
|
||||
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "USER"
|
||||
ASSISTANT = "ASSISTANT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistoryMessage:
|
||||
"""A single message in AWS Nova Sonic conversation history.
|
||||
|
||||
Parameters:
|
||||
role: The role of the message sender (USER or ASSISTANT only).
|
||||
text: The text content of the message.
|
||||
"""
|
||||
|
||||
role: Role # only USER and ASSISTANT
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistory:
|
||||
"""Complete conversation history for AWS Nova Sonic initialization.
|
||||
|
||||
Parameters:
|
||||
system_instruction: System-level instruction for the conversation.
|
||||
messages: List of conversation messages between user and assistant.
|
||||
"""
|
||||
|
||||
system_instruction: str = None
|
||||
messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMContext(OpenAILLMContext):
|
||||
"""Specialized LLM context for AWS Nova Sonic service.
|
||||
|
||||
Extends OpenAI context with Nova Sonic-specific message handling,
|
||||
conversation history management, and text buffering capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
"""Initialize AWS Nova Sonic LLM context.
|
||||
|
||||
Args:
|
||||
messages: Initial messages for the context.
|
||||
tools: Available tools for the context.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self, system_instruction: str = ""):
|
||||
self._assistant_text = ""
|
||||
self._user_text = ""
|
||||
self._system_instruction = system_instruction
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_nova_sonic(
|
||||
obj: OpenAILLMContext, system_instruction: str
|
||||
) -> "AWSNovaSonicLLMContext":
|
||||
"""Upgrade an OpenAI context to AWS Nova Sonic context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAI context to upgrade.
|
||||
system_instruction: System instruction for the context.
|
||||
|
||||
Returns:
|
||||
The upgraded AWS Nova Sonic context.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
|
||||
obj.__class__ = AWSNovaSonicLLMContext
|
||||
obj.__setup_local(system_instruction)
|
||||
return obj
|
||||
|
||||
# NOTE: this method has the side-effect of updating _system_instruction from messages
|
||||
def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
|
||||
"""Get conversation history for initializing AWS Nova Sonic session.
|
||||
|
||||
Processes stored messages and extracts system instruction and conversation
|
||||
history in the format expected by AWS Nova Sonic.
|
||||
|
||||
Returns:
|
||||
Formatted conversation history with system instruction and messages.
|
||||
"""
|
||||
history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)
|
||||
|
||||
# Bail if there are no messages
|
||||
if not self.messages:
|
||||
return history
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into "instruction"
|
||||
if messages[0].get("role") == "system":
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
history.system_instruction = content
|
||||
elif isinstance(content, list):
|
||||
history.system_instruction = content[0].get("text")
|
||||
if history.system_instruction:
|
||||
self._system_instruction = history.system_instruction
|
||||
|
||||
# Process remaining messages to fill out conversation history.
|
||||
# Nova Sonic supports "user" and "assistant" messages in history.
|
||||
for message in messages:
|
||||
history_message = self.from_standard_message(message)
|
||||
if history_message:
|
||||
history.messages.append(history_message)
|
||||
|
||||
return history
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Returns:
|
||||
List of messages including system instruction if present.
|
||||
"""
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
# If we have a system instruction and messages doesn't already contain it, add it
|
||||
if self._system_instruction and not (messages and messages[0].get("role") == "system"):
|
||||
messages.insert(0, {"role": "system", "content": self._system_instruction})
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage:
|
||||
"""Convert standard message format to Nova Sonic format.
|
||||
|
||||
Args:
|
||||
message: Standard message dictionary to convert.
|
||||
|
||||
Returns:
|
||||
Nova Sonic conversation history message, or None if not convertible.
|
||||
"""
|
||||
role = message.get("role")
|
||||
if message.get("role") == "user" or message.get("role") == "assistant":
|
||||
content = message.get("content")
|
||||
if isinstance(message.get("content"), list):
|
||||
content = ""
|
||||
for c in message.get("content"):
|
||||
if c.get("type") == "text":
|
||||
content += " " + c.get("text")
|
||||
else:
|
||||
logger.error(
|
||||
f"Unhandled content type in context message: {c.get('type')} - {message}"
|
||||
)
|
||||
# There won't be content if this is an assistant tool call entry.
|
||||
# We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
|
||||
# history
|
||||
if content:
|
||||
return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
|
||||
# NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova
|
||||
# Sonic conversation history
|
||||
|
||||
def buffer_user_text(self, text):
|
||||
"""Buffer user text for later flushing to context.
|
||||
|
||||
Args:
|
||||
text: User text to buffer.
|
||||
"""
|
||||
self._user_text += f" {text}" if self._user_text else text
|
||||
# logger.debug(f"User text buffered: {self._user_text}")
|
||||
|
||||
def flush_aggregated_user_text(self) -> str:
|
||||
"""Flush buffered user text to context as a complete message.
|
||||
|
||||
Returns:
|
||||
The flushed user text, or empty string if no text was buffered.
|
||||
"""
|
||||
if not self._user_text:
|
||||
return ""
|
||||
user_text = self._user_text
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}],
|
||||
}
|
||||
self._user_text = ""
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (user): {self.get_messages_for_logging()}")
|
||||
return user_text
|
||||
|
||||
def buffer_assistant_text(self, text):
|
||||
"""Buffer assistant text for later flushing to context.
|
||||
|
||||
Args:
|
||||
text: Assistant text to buffer.
|
||||
"""
|
||||
self._assistant_text += text
|
||||
# logger.debug(f"Assistant text buffered: {self._assistant_text}")
|
||||
|
||||
def flush_aggregated_assistant_text(self):
|
||||
"""Flush buffered assistant text to context as a complete message."""
|
||||
if not self._assistant_text:
|
||||
return
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": self._assistant_text}],
|
||||
}
|
||||
self._assistant_text = ""
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
|
||||
"""Frame containing updated AWS Nova Sonic context.
|
||||
|
||||
Parameters:
|
||||
context: The updated AWS Nova Sonic LLM context.
|
||||
"""
|
||||
|
||||
context: AWSNovaSonicLLMContext
|
||||
|
||||
|
||||
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""Context aggregator for user messages in AWS Nova Sonic conversations.
|
||||
|
||||
Extends the OpenAI user context aggregator to emit Nova Sonic-specific
|
||||
context update frames.
|
||||
"""
|
||||
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Process frames and emit Nova Sonic-specific context updates.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction the frame is traveling.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Parent does not push LLMMessagesUpdateFrame
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context))
|
||||
|
||||
|
||||
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Context aggregator for assistant messages in AWS Nova Sonic conversations.
|
||||
|
||||
Provides specialized handling for assistant responses and function calls
|
||||
in AWS Nova Sonic context, with custom frame processing logic.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Nova Sonic-specific logic.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction the frame is traveling.
|
||||
"""
|
||||
# HACK: For now, disable the context aggregator by making it just pass through all frames
|
||||
# that the parent handles (except the function call stuff, which we still need).
|
||||
# For an explanation of this hack, see
|
||||
# AWSNovaSonicLLMService._report_assistant_response_text_added.
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
TextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
UserImageRawFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
),
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call results for AWS Nova Sonic.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the LLM
|
||||
# itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side
|
||||
# context. Let's push a special frame to do that.
|
||||
await self.push_frame(
|
||||
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicContextAggregatorPair:
|
||||
"""Pair of user and assistant context aggregators for AWS Nova Sonic.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator.
|
||||
_assistant: The assistant context aggregator.
|
||||
"""
|
||||
|
||||
_user: AWSNovaSonicUserContextAggregator
|
||||
_assistant: AWSNovaSonicAssistantContextAggregator
|
||||
|
||||
def user(self) -> AWSNovaSonicUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws_nova_sonic.context are deprecated. \n"
|
||||
"AWS Nova Sonic now supports `LLMContext` and `LLMContextAggregatorPair`. \n"
|
||||
"Using the new patterns should allow you to not need types from this module.\n\n"
|
||||
"BEFORE:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = OpenAILLMContext(messages, tools)\n"
|
||||
"context_aggregator = llm.create_context_aggregator(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: OpenAILLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: AWSNovaSonicLLMContext\n"
|
||||
"# or\n"
|
||||
"context: OpenAILLMContext\n\n"
|
||||
"# Reading messages from context\n"
|
||||
"messages = context.messages\n"
|
||||
"```\n\n"
|
||||
"AFTER:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = LLMContext(messages, tools)\n"
|
||||
"context_aggregator = LLMContextAggregatorPair(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: LLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: LLMContext\n\n"
|
||||
"# Reading messages from context\n"
|
||||
"messages = context.messages\n"
|
||||
"```",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -6,20 +6,16 @@
|
||||
|
||||
"""Custom frames for AWS Nova Sonic LLM service."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
from pipecat.services.aws.nova_sonic.frames import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call result for AWS Nova Sonic processing.
|
||||
|
||||
This frame wraps a standard function call result frame to enable
|
||||
AWS Nova Sonic-specific handling and context updates.
|
||||
|
||||
Parameters:
|
||||
result_frame: The underlying function call result frame.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws_nova_sonic.frames are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.aws.nova_sonic.frames instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
0
src/pipecat/services/azure/realtime/__init__.py
Normal file
0
src/pipecat/services/azure/realtime/__init__.py
Normal file
65
src/pipecat/services/azure/realtime/llm.py
Normal file
65
src/pipecat/services/azure/realtime/llm.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Azure OpenAI Realtime LLM service implementation."""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Azure Realtime, you need to `pip install pipecat-ai[openai]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
    """Azure OpenAI Realtime LLM service with Azure-specific authentication.

    Extends the OpenAI Realtime service to work with Azure OpenAI endpoints,
    using Azure's authentication headers and endpoint format. Provides the same
    real-time audio and text communication capabilities as the base OpenAI service.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        **kwargs,
    ):
        """Initialize Azure Realtime LLM service.

        Args:
            api_key: The API key for the Azure OpenAI service.
            base_url: The full Azure WebSocket endpoint URL including api-version and deployment.
                Example: "wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment"
            **kwargs: Additional arguments passed to parent OpenAIRealtimeLLMService.
        """
        super().__init__(base_url=base_url, api_key=api_key, **kwargs)
        self.api_key = api_key
        self.base_url = base_url

    async def _connect(self):
        """Open the WebSocket to the Azure endpoint and start the receive task.

        Does nothing if a websocket already exists. The API key is sent in the
        ``api-key`` header. Any connection error is logged and leaves
        ``self._websocket`` set to ``None``.
        """
        try:
            if self._websocket:
                # Here we assume that if we have a websocket, we are connected. We
                # handle disconnections in the send/recv code paths.
                return

            # The API key is a credential: it must never be written to the logs,
            # so only the endpoint is reported here.
            logger.info(f"Connecting to {self.base_url}")
            self._websocket = await websocket_connect(
                uri=self.base_url,
                additional_headers={
                    "api-key": self.api_key,
                },
            )
            self._receive_task = self.create_task(self._receive_task_handler())
        except Exception as e:
            logger.error(f"{self} initialization error: {e}")
            self._websocket = None
|
||||
@@ -28,13 +28,12 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
@@ -124,7 +123,7 @@ class CartesiaLiveOptions:
|
||||
return cls(**json.loads(json_str))
|
||||
|
||||
|
||||
class CartesiaSTTService(STTService):
|
||||
class CartesiaSTTService(WebsocketSTTService):
|
||||
"""Speech-to-text service using Cartesia Live API.
|
||||
|
||||
Provides real-time speech transcription through WebSocket connection
|
||||
@@ -176,8 +175,7 @@ class CartesiaSTTService(STTService):
|
||||
self.set_model_name(merged_options.model)
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url or "api.cartesia.ai"
|
||||
self._connection = None
|
||||
self._receiver_task = None
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate processing metrics.
|
||||
@@ -214,6 +212,27 @@ class CartesiaSTTService(STTService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start performance metrics collection for transcription processing."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and handle speech events.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: Direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize command to flush the transcription session
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send("finalize")
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
@@ -224,45 +243,71 @@ class CartesiaSTTService(STTService):
|
||||
None - transcription results are handled via WebSocket responses.
|
||||
"""
|
||||
# If the connection is closed, due to timeout, we need to reconnect when the user starts speaking again
|
||||
if not self._connection or self._connection.state is State.CLOSED:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self._connection.send(audio)
|
||||
await self._websocket.send(audio)
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
params = self._settings.to_dict()
|
||||
ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}"
|
||||
logger.debug(f"Connecting to Cartesia: {ws_url}")
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = asyncio.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
self._connection = await websocket_connect(ws_url, additional_headers=headers)
|
||||
# Setup the receiver task to handle the incoming messages from the Cartesia server
|
||||
if self._receiver_task is None or self._receiver_task.done():
|
||||
self._receiver_task = asyncio.create_task(self._receive_messages())
|
||||
logger.debug(f"Connected to Cartesia")
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
logger.debug("Connecting to Cartesia STT")
|
||||
|
||||
params = self._settings.to_dict()
|
||||
ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}"
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to Cartesia: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
while True:
|
||||
if not self._connection or self._connection.state is State.CLOSED:
|
||||
break
|
||||
|
||||
message = await self._connection.recv()
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.debug(f"WebSocket connection closed: {e}")
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Cartesia STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message receiver: {e}")
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
await self._process_messages()
|
||||
# Cartesia times out after 5 minutes of inactivity (no keepalive
|
||||
# mechanism is available). So, we try to reconnect.
|
||||
logger.debug(f"{self} Cartesia connection was disconnected (timeout?), reconnecting")
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _process_response(self, data):
|
||||
if "type" in data:
|
||||
@@ -316,41 +361,3 @@ class CartesiaSTTService(STTService):
|
||||
language,
|
||||
)
|
||||
)
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receiver_task:
|
||||
self._receiver_task.cancel()
|
||||
try:
|
||||
await self._receiver_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected exception while cancelling task: {e}")
|
||||
self._receiver_task = None
|
||||
|
||||
if self._connection and self._connection.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start performance metrics collection for transcription processing."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and handle speech events.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: Direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize command to flush the transcription session
|
||||
if self._connection and self._connection.state is State.OPEN:
|
||||
await self._connection.send("finalize")
|
||||
|
||||
@@ -344,10 +344,11 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
logger.debug("Connecting to Cartesia")
|
||||
logger.debug("Connecting to Cartesia TTS")
|
||||
self._websocket = await websocket_connect(
|
||||
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -365,6 +366,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -205,6 +205,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -225,6 +226,9 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_close_stream(self) -> None:
|
||||
"""Sends a CloseStream control message to the Deepgram Flux WebSocket API.
|
||||
|
||||
@@ -168,16 +168,24 @@ def build_elevenlabs_voice_settings(
|
||||
|
||||
|
||||
def calculate_word_times(
|
||||
alignment_info: Mapping[str, Any], cumulative_time: float
|
||||
) -> List[Tuple[str, float]]:
|
||||
alignment_info: Mapping[str, Any],
|
||||
cumulative_time: float,
|
||||
partial_word: str = "",
|
||||
partial_word_start_time: float = 0.0,
|
||||
) -> tuple[List[Tuple[str, float]], str, float]:
|
||||
"""Calculate word timestamps from character alignment information.
|
||||
|
||||
Args:
|
||||
alignment_info: Character alignment data from ElevenLabs API.
|
||||
cumulative_time: Base time offset for this chunk.
|
||||
partial_word: Partial word carried over from previous chunk.
|
||||
partial_word_start_time: Start time of the partial word.
|
||||
|
||||
Returns:
|
||||
List of (word, timestamp) tuples.
|
||||
Tuple of (word_times, new_partial_word, new_partial_word_start_time):
|
||||
- word_times: List of (word, timestamp) tuples for complete words
|
||||
- new_partial_word: Incomplete word at end of chunk (empty if chunk ends with space)
|
||||
- new_partial_word_start_time: Start time of the incomplete word
|
||||
"""
|
||||
chars = alignment_info["chars"]
|
||||
char_start_times_ms = alignment_info["charStartTimesMs"]
|
||||
@@ -186,41 +194,37 @@ def calculate_word_times(
|
||||
logger.error(
|
||||
f"calculate_word_times: length mismatch - chars={len(chars)}, times={len(char_start_times_ms)}"
|
||||
)
|
||||
return []
|
||||
return ([], partial_word, partial_word_start_time)
|
||||
|
||||
# Build words and track their start positions
|
||||
words = []
|
||||
word_start_indices = []
|
||||
current_word = ""
|
||||
word_start_index = None
|
||||
word_start_times = []
|
||||
current_word = partial_word # Start with any partial word from previous chunk
|
||||
word_start_time = partial_word_start_time if partial_word else None
|
||||
|
||||
for i, char in enumerate(chars):
|
||||
if char == " ":
|
||||
# End of current word
|
||||
if current_word: # Only add non-empty words
|
||||
words.append(current_word)
|
||||
word_start_indices.append(word_start_index)
|
||||
word_start_times.append(word_start_time)
|
||||
current_word = ""
|
||||
word_start_index = None
|
||||
word_start_time = None
|
||||
else:
|
||||
# Building a word
|
||||
if word_start_index is None: # First character of new word
|
||||
word_start_index = i
|
||||
if word_start_time is None: # First character of new word
|
||||
# Convert from milliseconds to seconds and add cumulative offset
|
||||
word_start_time = cumulative_time + (char_start_times_ms[i] / 1000.0)
|
||||
current_word += char
|
||||
|
||||
# Handle the last word if there's no trailing space
|
||||
if current_word and word_start_index is not None:
|
||||
words.append(current_word)
|
||||
word_start_indices.append(word_start_index)
|
||||
# Build result for complete words
|
||||
word_times = list(zip(words, word_start_times))
|
||||
|
||||
# Calculate timestamps for each word
|
||||
word_times = []
|
||||
for word, start_idx in zip(words, word_start_indices):
|
||||
# Convert from milliseconds to seconds and add cumulative offset
|
||||
start_time_seconds = cumulative_time + (char_start_times_ms[start_idx] / 1000.0)
|
||||
word_times.append((word, start_time_seconds))
|
||||
# Return any incomplete word at the end of this chunk
|
||||
new_partial_word = current_word if current_word else ""
|
||||
new_partial_word_start_time = word_start_time if word_start_time is not None else 0.0
|
||||
|
||||
return word_times
|
||||
return (word_times, new_partial_word, new_partial_word_start_time)
|
||||
|
||||
|
||||
class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
@@ -332,6 +336,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
# there's an interruption or TTSStoppedFrame.
|
||||
self._started = False
|
||||
self._cumulative_time = 0
|
||||
# Track partial words that span across alignment chunks
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
# Context management for v1 multi API
|
||||
self._context_id = None
|
||||
@@ -521,6 +528,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
url, max_size=16 * 1024 * 1024, additional_headers={"xi-api-key": self._api_key}
|
||||
)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -543,6 +551,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
@@ -570,6 +579,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Handle incoming WebSocket messages from ElevenLabs."""
|
||||
@@ -609,7 +620,14 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
if msg.get("alignment"):
|
||||
alignment = msg["alignment"]
|
||||
word_times = calculate_word_times(alignment, self._cumulative_time)
|
||||
word_times, self._partial_word, self._partial_word_start_time = (
|
||||
calculate_word_times(
|
||||
alignment,
|
||||
self._cumulative_time,
|
||||
self._partial_word,
|
||||
self._partial_word_start_time,
|
||||
)
|
||||
)
|
||||
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times)
|
||||
@@ -683,6 +701,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
# If a context ID does not exist, create a new one and
|
||||
# register it. If an ID exists, that means the Pipeline is
|
||||
# configured for allow_interruptions=False, so continue
|
||||
@@ -756,6 +776,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
base_url: str = "https://api.elevenlabs.io",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs HTTP TTS service.
|
||||
@@ -768,10 +789,11 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
base_url: Base URL for ElevenLabs HTTP API.
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
@@ -809,6 +831,10 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
# Store previous text for context within a turn
|
||||
self._previous_text = ""
|
||||
|
||||
# Track partial words that span across alignment chunks
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language to ElevenLabs language code.
|
||||
|
||||
@@ -836,6 +862,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
self._cumulative_time = 0
|
||||
self._started = False
|
||||
self._previous_text = ""
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
logger.debug(f"{self}: Reset internal state")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
@@ -870,11 +898,13 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
def calculate_word_times(self, alignment_info: Mapping[str, Any]) -> List[Tuple[str, float]]:
|
||||
"""Calculate word timing from character alignment data.
|
||||
|
||||
This method handles partial words that may span across multiple alignment chunks.
|
||||
|
||||
Args:
|
||||
alignment_info: Character timing data from ElevenLabs.
|
||||
|
||||
Returns:
|
||||
List of (word, timestamp) pairs.
|
||||
List of (word, timestamp) pairs for complete words in this chunk.
|
||||
|
||||
Example input data::
|
||||
|
||||
@@ -900,30 +930,28 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
# Build the words and find their start times
|
||||
words = []
|
||||
word_start_times = []
|
||||
current_word = ""
|
||||
first_char_idx = -1
|
||||
# Start with any partial word from previous chunk
|
||||
current_word = self._partial_word
|
||||
word_start_time = self._partial_word_start_time if self._partial_word else None
|
||||
|
||||
for i, char in enumerate(chars):
|
||||
if char == " ":
|
||||
if current_word: # Only add non-empty words
|
||||
words.append(current_word)
|
||||
# Use time of the first character of the word, offset by cumulative time
|
||||
word_start_times.append(
|
||||
self._cumulative_time + char_start_times[first_char_idx]
|
||||
)
|
||||
word_start_times.append(word_start_time)
|
||||
current_word = ""
|
||||
first_char_idx = -1
|
||||
word_start_time = None
|
||||
else:
|
||||
if not current_word: # This is the first character of a new word
|
||||
first_char_idx = i
|
||||
if word_start_time is None: # First character of a new word
|
||||
# Use time of the first character of the word, offset by cumulative time
|
||||
word_start_time = self._cumulative_time + char_start_times[i]
|
||||
current_word += char
|
||||
|
||||
# Don't forget the last word if there's no trailing space
|
||||
if current_word and first_char_idx >= 0:
|
||||
words.append(current_word)
|
||||
word_start_times.append(self._cumulative_time + char_start_times[first_char_idx])
|
||||
# Store any incomplete word at the end of this chunk
|
||||
self._partial_word = current_word if current_word else ""
|
||||
self._partial_word_start_time = word_start_time if word_start_time is not None else 0.0
|
||||
|
||||
# Create word-time pairs
|
||||
# Create word-time pairs for complete words only
|
||||
word_times = list(zip(words, word_start_times))
|
||||
|
||||
return word_times
|
||||
@@ -959,6 +987,9 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
if self._voice_settings:
|
||||
payload["voice_settings"] = self._voice_settings
|
||||
|
||||
if self._settings["apply_text_normalization"] is not None:
|
||||
payload["apply_text_normalization"] = self._settings["apply_text_normalization"]
|
||||
|
||||
language = self._settings["language"]
|
||||
if self._model_name in ELEVENLABS_MULTILINGUAL_MODELS and language:
|
||||
payload["language_code"] = language
|
||||
@@ -979,8 +1010,6 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
}
|
||||
if self._settings["optimize_streaming_latency"] is not None:
|
||||
params["optimize_streaming_latency"] = self._settings["optimize_streaming_latency"]
|
||||
if self._settings["apply_text_normalization"] is not None:
|
||||
params["apply_text_normalization"] = self._settings["apply_text_normalization"]
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -1041,6 +1070,14 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
logger.error(f"Error processing response: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
# After processing all chunks, emit any remaining partial word
|
||||
# since this is the end of the utterance
|
||||
if self._partial_word:
|
||||
final_word_time = [(self._partial_word, self._partial_word_start_time)]
|
||||
await self.add_word_timestamps(final_word_time)
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
# After processing all chunks, add the total utterance duration
|
||||
# to the cumulative time to ensure next utterance starts after this one
|
||||
if utterance_duration > 0:
|
||||
|
||||
@@ -225,6 +225,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
start_message = {"event": "start", "request": {"text": "", **self._settings}}
|
||||
await self._websocket.send(ormsgpack.packb(start_message))
|
||||
logger.debug("Sent start message to Fish Audio")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -245,6 +247,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any buffered audio by sending a flush event to Fish Audio."""
|
||||
|
||||
@@ -4,527 +4,41 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event models and utilities for Google Gemini Multimodal Live API."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import ImageRawFrame
|
||||
|
||||
#
|
||||
# Client events
|
||||
#
|
||||
|
||||
|
||||
class MediaChunk(BaseModel):
|
||||
"""Represents a chunk of media data for transmission.
|
||||
|
||||
Parameters:
|
||||
mimeType: MIME type of the media content.
|
||||
data: Base64-encoded media data.
|
||||
"""
|
||||
|
||||
mimeType: str
|
||||
data: str
|
||||
|
||||
|
||||
class ContentPart(BaseModel):
|
||||
"""Represents a part of content that can contain text or media.
|
||||
|
||||
Parameters:
|
||||
text: Text content. Defaults to None.
|
||||
inlineData: Inline media data. Defaults to None.
|
||||
"""
|
||||
|
||||
text: Optional[str] = Field(default=None, validate_default=False)
|
||||
inlineData: Optional[MediaChunk] = Field(default=None, validate_default=False)
|
||||
fileData: Optional["FileData"] = Field(default=None, validate_default=False)
|
||||
|
||||
|
||||
class FileData(BaseModel):
|
||||
"""Represents a file reference in the Gemini File API."""
|
||||
|
||||
mimeType: str
|
||||
fileUri: str
|
||||
|
||||
|
||||
ContentPart.model_rebuild() # Rebuild model to resolve forward reference
|
||||
|
||||
|
||||
class Turn(BaseModel):
|
||||
"""Represents a conversational turn in the dialogue.
|
||||
|
||||
Parameters:
|
||||
role: The role of the speaker, either "user" or "model". Defaults to "user".
|
||||
parts: List of content parts that make up the turn.
|
||||
"""
|
||||
|
||||
role: Literal["user", "model"] = "user"
|
||||
parts: List[ContentPart]
|
||||
|
||||
|
||||
class StartSensitivity(str, Enum):
|
||||
"""Determines how start of speech is detected."""
|
||||
|
||||
UNSPECIFIED = "START_SENSITIVITY_UNSPECIFIED" # Default is HIGH
|
||||
HIGH = "START_SENSITIVITY_HIGH" # Detect start of speech more often
|
||||
LOW = "START_SENSITIVITY_LOW" # Detect start of speech less often
|
||||
|
||||
|
||||
class EndSensitivity(str, Enum):
    """Sensitivity levels controlling how end of speech is detected."""

    # Default is HIGH
    UNSPECIFIED = "END_SENSITIVITY_UNSPECIFIED"
    # End speech more often
    HIGH = "END_SENSITIVITY_HIGH"
    # End speech less often
    LOW = "END_SENSITIVITY_LOW"
|
||||
|
||||
|
||||
class AutomaticActivityDetection(BaseModel):
    """Settings for automatic voice activity detection.

    Every field is optional and defaults to None.

    Parameters:
        disabled: Whether automatic activity detection is disabled. Defaults to None.
        start_of_speech_sensitivity: Sensitivity used to detect the start of speech.
            Defaults to None.
        prefix_padding_ms: Padding before the start of speech, in milliseconds.
            Defaults to None.
        end_of_speech_sensitivity: Sensitivity used to detect the end of speech.
            Defaults to None.
        silence_duration_ms: Duration of silence used to detect the end of speech.
            Defaults to None.
    """

    disabled: Optional[bool] = None

    # Start-of-speech settings.
    start_of_speech_sensitivity: Optional[StartSensitivity] = None
    prefix_padding_ms: Optional[int] = None

    # End-of-speech settings.
    end_of_speech_sensitivity: Optional[EndSensitivity] = None
    silence_duration_ms: Optional[int] = None
|
||||
|
||||
|
||||
class RealtimeInputConfig(BaseModel):
    """Settings that control realtime input behavior.

    Parameters:
        automatic_activity_detection: Configuration of automatic voice activity
            detection. Defaults to None.
    """

    automatic_activity_detection: Optional[AutomaticActivityDetection] = None
|
||||
|
||||
|
||||
class RealtimeInput(BaseModel):
    """Realtime input payload: media chunks and/or text.

    Parameters:
        mediaChunks: Media chunks for realtime processing. Defaults to None.
        text: Text for realtime processing. Defaults to None.
    """

    # Both fields are optional, so a message can carry only one of them.
    mediaChunks: Optional[List[MediaChunk]] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class ClientContent(BaseModel):
    """Conversation content the client sends to the Gemini Live API.

    Parameters:
        turns: Conversation turns to send. Defaults to None.
        turnComplete: Whether the client's turn is complete. Defaults to False.
    """

    turns: Optional[List[Turn]] = None
    turnComplete: bool = False
|
||||
|
||||
|
||||
class AudioInputMessage(BaseModel):
    """Client message that carries audio as realtime input.

    Parameters:
        realtimeInput: Realtime input holding the audio chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_raw_audio(cls, raw_audio: bytes, sample_rate: int) -> "AudioInputMessage":
        """Build an audio input message from raw audio bytes.

        Args:
            raw_audio: Raw audio bytes.
            sample_rate: Audio sample rate in Hz; embedded in the chunk's MIME type.

        Returns:
            AudioInputMessage whose realtime input holds a single media chunk
            with the Base64-encoded audio.
        """
        encoded_audio = base64.b64encode(raw_audio).decode("utf-8")
        chunk = MediaChunk(mimeType=f"audio/pcm;rate={sample_rate}", data=encoded_audio)
        return cls(realtimeInput=RealtimeInput(mediaChunks=[chunk]))
|
||||
|
||||
|
||||
class VideoInputMessage(BaseModel):
    """Message containing video/image input data.

    Parameters:
        realtimeInput: Realtime input containing video/image chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_image_frame(cls, frame: ImageRawFrame) -> "VideoInputMessage":
        """Create a video input message from an image frame.

        The frame is encoded as JPEG and Base64-encoded into a single media
        chunk with MIME type ``image/jpeg``.

        Args:
            frame: Image frame to encode. Its ``format``, ``size`` and ``image``
                attributes are passed to ``PIL.Image.frombytes``.

        Returns:
            VideoInputMessage instance with encoded image data.
        """
        image = Image.frombytes(frame.format, frame.size, frame.image)

        # Pillow's JPEG writer only encodes these modes; anything else (for
        # example RGBA, which carries an alpha channel) raises OSError on save,
        # so convert such images to RGB first.
        if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            image = image.convert("RGB")

        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        data = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return cls(
            realtimeInput=RealtimeInput(mediaChunks=[MediaChunk(mimeType="image/jpeg", data=data)])
        )
|
||||
|
||||
|
||||
class TextInputMessage(BaseModel):
    """Client message that carries text as realtime input.

    Parameters:
        realtimeInput: Realtime input holding the text.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_text(cls, text: str) -> "TextInputMessage":
        """Build a text input message from a string.

        Args:
            text: The text to send.

        Returns:
            TextInputMessage whose realtime input has only ``text`` set.
        """
        realtime_input = RealtimeInput(text=text)
        return cls(realtimeInput=realtime_input)
|
||||
|
||||
|
||||
class ClientContentMessage(BaseModel):
    """Envelope message wrapping client content for the API.

    Parameters:
        clientContent: The client content to send (required).
    """

    clientContent: ClientContent
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
    """The system instruction given to the model.

    Parameters:
        parts: Content parts that together form the system instruction (required).
    """

    parts: List[ContentPart]
|
||||
|
||||
|
||||
class AudioTranscriptionConfig(BaseModel):
    """Audio transcription configuration.

    The model declares no fields of its own.
    """
|
||||
|
||||
|
||||
class Setup(BaseModel):
    """Session setup parameters for Gemini Live.

    Only ``model`` is required; every other field defaults to None.

    Parameters:
        model: Identifier of the model to use.
        system_instruction: System instruction for the model. Defaults to None.
        tools: Available tools/functions, as a list of dicts. Defaults to None.
        generation_config: Generation configuration parameters, as a dict.
            Defaults to None.
        input_audio_transcription: Input audio transcription config. Defaults to None.
        output_audio_transcription: Output audio transcription config. Defaults to None.
        realtime_input_config: Realtime input configuration. Defaults to None.
    """

    model: str
    system_instruction: Optional[SystemInstruction] = None

    # Tools and generation settings are passed through as plain dicts.
    tools: Optional[List[dict]] = None
    generation_config: Optional[dict] = None

    input_audio_transcription: Optional[AudioTranscriptionConfig] = None
    output_audio_transcription: Optional[AudioTranscriptionConfig] = None
    realtime_input_config: Optional[RealtimeInputConfig] = None
|
||||
|
||||
|
||||
class Config(BaseModel):
    """Envelope message carrying the session setup.

    Parameters:
        setup: Setup configuration for the session (required).
    """

    setup: Setup
|
||||
|
||||
|
||||
#
|
||||
# Grounding metadata models
|
||||
#
|
||||
|
||||
|
||||
class SearchEntryPoint(BaseModel):
    """Search entry point holding rendered content for search suggestions.

    Parameters:
        renderedContent: The rendered content. Defaults to None.
    """

    renderedContent: Optional[str] = None
|
||||
|
||||
|
||||
class WebSource(BaseModel):
    """A web source taken from a grounding chunk.

    Parameters:
        uri: URI of the source. Defaults to None.
        title: Title of the source. Defaults to None.
    """

    uri: Optional[str] = None
    title: Optional[str] = None
|
||||
|
||||
|
||||
class GroundingChunk(BaseModel):
    """A grounding chunk wrapping web source information.

    Parameters:
        web: The web source for this chunk. Defaults to None.
    """

    web: Optional[WebSource] = None
|
||||
|
||||
|
||||
class GroundingSegment(BaseModel):
    """A segment of text that is grounded.

    Parameters:
        startIndex: Start index of the segment. Defaults to None.
        endIndex: End index of the segment. Defaults to None.
        text: Text of the segment. Defaults to None.
    """

    startIndex: Optional[int] = None
    endIndex: Optional[int] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class GroundingSupport(BaseModel):
    """Support information for a grounded text segment.

    Parameters:
        segment: The grounded segment. Defaults to None.
        groundingChunkIndices: Integer indices of grounding chunks. Defaults to None.
        confidenceScores: Confidence scores, as floats. Defaults to None.
    """

    segment: Optional[GroundingSegment] = None
    groundingChunkIndices: Optional[List[int]] = None
    confidenceScores: Optional[List[float]] = None
|
||||
|
||||
|
||||
class GroundingMetadata(BaseModel):
    """Grounding metadata from Google Search.

    Parameters:
        searchEntryPoint: Search entry point. Defaults to None.
        groundingChunks: Grounding chunks. Defaults to None.
        groundingSupports: Support information for grounded segments. Defaults to None.
        webSearchQueries: Web search queries, as strings. Defaults to None.
    """

    searchEntryPoint: Optional[SearchEntryPoint] = None
    groundingChunks: Optional[List[GroundingChunk]] = None
    groundingSupports: Optional[List[GroundingSupport]] = None
    webSearchQueries: Optional[List[str]] = None
|
||||
|
||||
|
||||
#
|
||||
# Server events
|
||||
#
|
||||
|
||||
|
||||
class SetupComplete(BaseModel):
    """Notification that session setup is complete.

    The model declares no fields of its own.
    """
|
||||
|
||||
|
||||
class InlineData(BaseModel):
    """Data embedded inline in a server response.

    Parameters:
        mimeType: MIME type of the data.
        data: The data content, Base64-encoded as a string.
    """

    # Both fields are required.
    mimeType: str
    data: str
|
||||
|
||||
|
||||
class Part(BaseModel):
    """One part of a server response, holding inline data and/or text.

    Parameters:
        inlineData: Inline binary data. Defaults to None.
        text: Text content. Defaults to None.
    """

    inlineData: Optional[InlineData] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class ModelTurn(BaseModel):
    """A conversation turn produced by the model.

    Parameters:
        parts: Content parts of the model's response (required).
    """

    parts: List[Part]
|
||||
|
||||
|
||||
class ServerContentInterrupted(BaseModel):
    """Signals that server content was interrupted.

    Parameters:
        interrupted: Whether the content was interrupted (required).
    """

    interrupted: bool
|
||||
|
||||
|
||||
class ServerContentTurnComplete(BaseModel):
    """Signals that the server's turn is complete.

    Parameters:
        turnComplete: Whether the turn is complete (required).
    """

    turnComplete: bool
|
||||
|
||||
|
||||
class BidiGenerateContentTranscription(BaseModel):
    """Transcription produced during bidirectional content generation.

    Parameters:
        text: The transcribed text (required).
    """

    text: str
|
||||
|
||||
|
||||
class ServerContent(BaseModel):
    """Content the server sends to the client.

    Every field is optional and defaults to None.

    Parameters:
        modelTurn: The model's conversational turn. Defaults to None.
        interrupted: Whether content was interrupted. Defaults to None.
        turnComplete: Whether the turn is complete. Defaults to None.
        inputTranscription: Transcription of input audio. Defaults to None.
        outputTranscription: Transcription of output audio. Defaults to None.
        groundingMetadata: Grounding metadata. Defaults to None.
    """

    modelTurn: Optional[ModelTurn] = None

    # Turn-state flags.
    interrupted: Optional[bool] = None
    turnComplete: Optional[bool] = None

    # Transcriptions of the audio in each direction.
    inputTranscription: Optional[BidiGenerateContentTranscription] = None
    outputTranscription: Optional[BidiGenerateContentTranscription] = None

    groundingMetadata: Optional[GroundingMetadata] = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
    """A function call requested by the model.

    Parameters:
        id: Unique identifier of the function call.
        name: Name of the function to call.
        args: Arguments to pass to the function, as a dict.
    """

    # All three fields are required.
    id: str
    name: str
    args: dict
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """A group of function calls.

    Parameters:
        functionCalls: The function calls to execute (required).
    """

    functionCalls: List[FunctionCall]
|
||||
|
||||
|
||||
class Modality(str, Enum):
    """Modality names used in token counts.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    UNSPECIFIED = "MODALITY_UNSPECIFIED"

    # Concrete modalities; each value equals the member name.
    TEXT = "TEXT"
    IMAGE = "IMAGE"
    AUDIO = "AUDIO"
    VIDEO = "VIDEO"
|
||||
|
||||
|
||||
class ModalityTokenCount(BaseModel):
    """Token count attributed to one modality.

    Parameters:
        modality: The modality type.
        tokenCount: Number of tokens for this modality.
    """

    # Both fields are required.
    modality: Modality
    tokenCount: int
|
||||
|
||||
|
||||
class UsageMetadata(BaseModel):
    """Token usage metadata for an API response.

    Every field is optional and defaults to None.

    Parameters:
        promptTokenCount: Number of tokens in the prompt. Defaults to None.
        cachedContentTokenCount: Number of cached content tokens. Defaults to None.
        responseTokenCount: Number of tokens in the response. Defaults to None.
        toolUsePromptTokenCount: Number of tokens for tool use prompts. Defaults to None.
        thoughtsTokenCount: Number of tokens for model thoughts. Defaults to None.
        totalTokenCount: Total number of tokens used. Defaults to None.
        promptTokensDetails: Per-modality breakdown of prompt tokens. Defaults to None.
        cacheTokensDetails: Per-modality breakdown of cache tokens. Defaults to None.
        responseTokensDetails: Per-modality breakdown of response tokens. Defaults to None.
        toolUsePromptTokensDetails: Per-modality breakdown of tool use tokens.
            Defaults to None.
    """

    # Aggregate counts.
    promptTokenCount: Optional[int] = None
    cachedContentTokenCount: Optional[int] = None
    responseTokenCount: Optional[int] = None
    toolUsePromptTokenCount: Optional[int] = None
    thoughtsTokenCount: Optional[int] = None
    totalTokenCount: Optional[int] = None

    # Per-modality breakdowns.
    promptTokensDetails: Optional[List[ModalityTokenCount]] = None
    cacheTokensDetails: Optional[List[ModalityTokenCount]] = None
    responseTokensDetails: Optional[List[ModalityTokenCount]] = None
    toolUsePromptTokensDetails: Optional[List[ModalityTokenCount]] = None
|
||||
|
||||
|
||||
class ServerEvent(BaseModel):
    """An event received from the Gemini Live API.

    Every field is optional and defaults to None.

    Parameters:
        setupComplete: Setup completion notification. Defaults to None.
        serverContent: Content from the server. Defaults to None.
        toolCall: Tool/function call request. Defaults to None.
        usageMetadata: Token usage metadata. Defaults to None.
    """

    setupComplete: Optional[SetupComplete] = None
    serverContent: Optional[ServerContent] = None
    toolCall: Optional[ToolCall] = None
    usageMetadata: Optional[UsageMetadata] = None
|
||||
|
||||
|
||||
def parse_server_event(str):
    """Decode a JSON string and validate it as a ServerEvent.

    Any exception raised while decoding or validating is caught; the error
    is printed to stdout and None is returned instead.

    Args:
        str: JSON string containing the server event. (The parameter name
            shadows the builtin ``str`` inside this function.)

    Returns:
        ServerEvent instance if parsing succeeds, None otherwise.
    """
    try:
        payload = json.loads(str)
        event = ServerEvent.model_validate(payload)
    except Exception as e:
        print(f"Error parsing server event: {e}")
        return None
    return event
|
||||
|
||||
|
||||
class ContextWindowCompressionConfig(BaseModel):
    """Settings for context window compression.

    Parameters:
        sliding_window: Whether to use sliding window compression. Defaults to True.
        trigger_tokens: Token count threshold that triggers compression.
            Defaults to None.
    """

    sliding_window: Optional[bool] = True
    trigger_tokens: Optional[int] = None
|
||||
"""Event models and utilities for Google Gemini Multimodal Live API.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Importing StartSensitivity and EndSensitivity from this module is deprecated.
|
||||
Import them directly from google.genai.types instead.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from google.genai.types import (
|
||||
EndSensitivity as _EndSensitivity,
|
||||
)
|
||||
from google.genai.types import (
|
||||
StartSensitivity as _StartSensitivity,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# Backward-compatibility aliases: this module used to define the public-facing
# StartSensitivity and EndSensitivity enums itself, so the names are kept here
# and bound to the enums imported from google.genai.types.
with warnings.catch_warnings():
    # Temporarily set the filter to "always" while the notice is emitted.
    warnings.simplefilter("always")
    warnings.warn(
        message=(
            "Importing StartSensitivity and EndSensitivity from "
            "pipecat.services.gemini_multimodal_live.events is deprecated. "
            "Please import them directly from google.genai.types instead."
        ),
        category=DeprecationWarning,
        stacklevel=2,
    )

StartSensitivity = _StartSensitivity
EndSensitivity = _EndSensitivity
|
||||
|
||||
@@ -9,181 +9,31 @@
|
||||
This module provides a client for Google's Gemini File API, enabling file
|
||||
uploads, metadata retrieval, listing, and deletion. Files uploaded through
|
||||
this API can be referenced in Gemini generative model calls.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Importing GeminiFileAPI from this module is deprecated.
|
||||
Import it from pipecat.services.google.gemini_live.file_api instead.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from typing import Any, Dict, Optional
|
||||
import warnings
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from pipecat.services.google.gemini_live.file_api import GeminiFileAPI as _GeminiFileAPI
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
class GeminiFileAPI:
    """Asynchronous client for Google's Gemini File API.

    Provides coroutines to upload a file, fetch a file's metadata, list
    uploaded files, and delete a file. Every call opens its own
    ``aiohttp.ClientSession`` and passes the API key as the ``key`` query
    parameter.

    Files uploaded through this API remain available for 48 hours and can be referenced
    in calls to the Gemini generative models. Maximum file size is 2GB, with total
    project storage limited to 20GB.
    """

    def __init__(
        self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com/v1beta/files"
    ):
        """Store the credentials and endpoints used by the client.

        Args:
            api_key: Google AI API key
            base_url: Base URL for the Gemini File API (default is the v1beta endpoint)
        """
        self._base_url = base_url
        self._api_key = api_key
        # Uploads go through the separate /upload/ endpoint; it is fixed here
        # and is not derived from base_url.
        self.upload_base_url = "https://generativelanguage.googleapis.com/upload/v1beta/files"

    async def upload_file(
        self, file_path: str, display_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Upload a file with the two-step resumable upload protocol.

        The first request announces the upload and yields an upload URL in the
        ``X-Goog-Upload-URL`` response header; the second request sends the
        whole file to that URL and finalizes the upload.

        Args:
            file_path: Path to the file to upload
            display_name: Optional display name for the file

        Returns:
            The JSON body of the final upload response.

        Raises:
            Exception: If either request returns a non-200 status, or if the
                first response carries no upload URL.
        """
        logger.info(f"Uploading file: {file_path}")

        async with aiohttp.ClientSession() as session:
            # Fall back to a generic binary type when the type cannot be guessed.
            guessed_type, _ = mimetypes.guess_type(file_path)
            mime_type = guessed_type or "application/octet-stream"

            # The whole file is read into memory.
            with open(file_path, "rb") as source:
                payload = source.read()
            content_length = str(len(payload))

            # The metadata body stays empty unless a display name was given.
            metadata = {"file": {"display_name": display_name}} if display_name else {}

            # Step 1: announce the upload and obtain the upload URL.
            start_headers = {
                "X-Goog-Upload-Protocol": "resumable",
                "X-Goog-Upload-Command": "start",
                "X-Goog-Upload-Header-Content-Length": content_length,
                "X-Goog-Upload-Header-Content-Type": mime_type,
                "Content-Type": "application/json",
            }

            logger.debug(f"Step 1: Getting upload URL from {self.upload_base_url}")
            start_url = f"{self.upload_base_url}?key={self._api_key}"
            async with session.post(start_url, headers=start_headers, json=metadata) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error initiating file upload: {error_text}")
                    raise Exception(f"Failed to initiate upload: {response.status} - {error_text}")

                # The upload URL arrives in a response header, not in the body.
                upload_url = response.headers.get("X-Goog-Upload-URL")
                if not upload_url:
                    logger.error(f"Response headers: {dict(response.headers)}")
                    raise Exception("No upload URL in response headers")

                logger.debug(f"Got upload URL: {upload_url}")

            # Step 2: send all bytes in one request and finalize.
            upload_headers = {
                "Content-Length": content_length,
                "X-Goog-Upload-Offset": "0",
                "X-Goog-Upload-Command": "upload, finalize",
            }

            logger.debug(f"Step 2: Uploading file data to {upload_url}")
            async with session.post(upload_url, headers=upload_headers, data=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error uploading file data: {error_text}")
                    raise Exception(f"Failed to upload file: {response.status} - {error_text}")

                file_info = await response.json()
                logger.info(f"File uploaded successfully: {file_info.get('file', {}).get('name')}")
                return file_info

    async def get_file(self, name: str) -> Dict[str, Any]:
        """Fetch the metadata of a file.

        Args:
            name: File name (or full path); only the segment after the last
                "/" is used.

        Returns:
            The JSON body of the response.

        Raises:
            Exception: If the response status is not 200.
        """
        # Keep only the last path segment; a name without "/" is unchanged.
        short_name = name.rsplit("/", 1)[-1]
        url = f"{self._base_url}/{short_name}?key={self._api_key}"

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error getting file metadata: {error_text}")
                    raise Exception(f"Failed to get file metadata: {response.status}")

                return await response.json()

    async def list_files(
        self, page_size: int = 10, page_token: Optional[str] = None
    ) -> Dict[str, Any]:
        """List uploaded files, one page at a time.

        Args:
            page_size: Number of files to return per page
            page_token: Token for pagination; only sent when truthy

        Returns:
            The JSON body of the response.

        Raises:
            Exception: If the response status is not 200.
        """
        query = {"key": self._api_key, "pageSize": page_size}
        if page_token:
            query["pageToken"] = page_token

        async with aiohttp.ClientSession() as session:
            async with session.get(self._base_url, params=query) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error listing files: {error_text}")
                    raise Exception(f"Failed to list files: {response.status}")

                return await response.json()

    async def delete_file(self, name: str) -> bool:
        """Delete a file.

        Args:
            name: File name (or full path); only the segment after the last
                "/" is used.

        Returns:
            True if deleted successfully

        Raises:
            Exception: If the response status is not 200.
        """
        # Keep only the last path segment; a name without "/" is unchanged.
        short_name = name.rsplit("/", 1)[-1]
        url = f"{self._base_url}/{short_name}?key={self._api_key}"

        async with aiohttp.ClientSession() as session:
            async with session.delete(url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error deleting file: {error_text}")
                    raise Exception(f"Failed to delete file: {response.status}")

                return True
|
||||
# This alias is just here for backward compatibility, since we used to
|
||||
# define the public-facing GeminiFileAPI class in this
|
||||
# module.
|
||||
# Emit the deprecation notice for this import path, then bind the old public
# name to the class imported as _GeminiFileAPI.
warnings.warn(
    message=(
        "Importing GeminiFileAPI from "
        "pipecat.services.gemini_multimodal_live.file_api is deprecated. "
        "Please import it from pipecat.services.google.gemini_live.file_api instead."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)

GeminiFileAPI = _GeminiFileAPI
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ import sys
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .frames import *
|
||||
from .gemini_live import *
|
||||
from .image import *
|
||||
from .llm import *
|
||||
from .llm_openai import *
|
||||
|
||||
3
src/pipecat/services/google/gemini_live/__init__.py
Normal file
3
src/pipecat/services/google/gemini_live/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .file_api import GeminiFileAPI
|
||||
from .llm import GeminiLiveLLMService
|
||||
from .llm_vertex import GeminiLiveVertexLLMService
|
||||
189
src/pipecat/services/google/gemini_live/file_api.py
Normal file
189
src/pipecat/services/google/gemini_live/file_api.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gemini File API client for uploading and managing files.
|
||||
|
||||
This module provides a client for Google's Gemini File API, enabling file
|
||||
uploads, metadata retrieval, listing, and deletion. Files uploaded through
|
||||
this API can be referenced in Gemini generative model calls.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GeminiFileAPI:
    """Asynchronous client for Google's Gemini File API.

    Provides coroutines to upload a file, fetch a file's metadata, list
    uploaded files, and delete a file. Every call opens its own
    ``aiohttp.ClientSession`` and passes the API key as the ``key`` query
    parameter.

    Files uploaded through this API remain available for 48 hours and can be referenced
    in calls to the Gemini generative models. Maximum file size is 2GB, with total
    project storage limited to 20GB.
    """

    def __init__(
        self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com/v1beta/files"
    ):
        """Store the credentials and endpoints used by the client.

        Args:
            api_key: Google AI API key
            base_url: Base URL for the Gemini File API (default is the v1beta endpoint)
        """
        self._base_url = base_url
        self._api_key = api_key
        # Uploads go through the separate /upload/ endpoint; it is fixed here
        # and is not derived from base_url.
        self.upload_base_url = "https://generativelanguage.googleapis.com/upload/v1beta/files"

    async def upload_file(
        self, file_path: str, display_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Upload a file with the two-step resumable upload protocol.

        The first request announces the upload and yields an upload URL in the
        ``X-Goog-Upload-URL`` response header; the second request sends the
        whole file to that URL and finalizes the upload.

        Args:
            file_path: Path to the file to upload
            display_name: Optional display name for the file

        Returns:
            The JSON body of the final upload response.

        Raises:
            Exception: If either request returns a non-200 status, or if the
                first response carries no upload URL.
        """
        logger.info(f"Uploading file: {file_path}")

        async with aiohttp.ClientSession() as session:
            # Fall back to a generic binary type when the type cannot be guessed.
            guessed_type, _ = mimetypes.guess_type(file_path)
            mime_type = guessed_type or "application/octet-stream"

            # The whole file is read into memory.
            with open(file_path, "rb") as source:
                payload = source.read()
            content_length = str(len(payload))

            # The metadata body stays empty unless a display name was given.
            metadata = {"file": {"display_name": display_name}} if display_name else {}

            # Step 1: announce the upload and obtain the upload URL.
            start_headers = {
                "X-Goog-Upload-Protocol": "resumable",
                "X-Goog-Upload-Command": "start",
                "X-Goog-Upload-Header-Content-Length": content_length,
                "X-Goog-Upload-Header-Content-Type": mime_type,
                "Content-Type": "application/json",
            }

            logger.debug(f"Step 1: Getting upload URL from {self.upload_base_url}")
            start_url = f"{self.upload_base_url}?key={self._api_key}"
            async with session.post(start_url, headers=start_headers, json=metadata) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error initiating file upload: {error_text}")
                    raise Exception(f"Failed to initiate upload: {response.status} - {error_text}")

                # The upload URL arrives in a response header, not in the body.
                upload_url = response.headers.get("X-Goog-Upload-URL")
                if not upload_url:
                    logger.error(f"Response headers: {dict(response.headers)}")
                    raise Exception("No upload URL in response headers")

                logger.debug(f"Got upload URL: {upload_url}")

            # Step 2: send all bytes in one request and finalize.
            upload_headers = {
                "Content-Length": content_length,
                "X-Goog-Upload-Offset": "0",
                "X-Goog-Upload-Command": "upload, finalize",
            }

            logger.debug(f"Step 2: Uploading file data to {upload_url}")
            async with session.post(upload_url, headers=upload_headers, data=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error uploading file data: {error_text}")
                    raise Exception(f"Failed to upload file: {response.status} - {error_text}")

                file_info = await response.json()
                logger.info(f"File uploaded successfully: {file_info.get('file', {}).get('name')}")
                return file_info

    async def get_file(self, name: str) -> Dict[str, Any]:
        """Fetch the metadata of a file.

        Args:
            name: File name (or full path); only the segment after the last
                "/" is used.

        Returns:
            The JSON body of the response.

        Raises:
            Exception: If the response status is not 200.
        """
        # Keep only the last path segment; a name without "/" is unchanged.
        short_name = name.rsplit("/", 1)[-1]
        url = f"{self._base_url}/{short_name}?key={self._api_key}"

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error getting file metadata: {error_text}")
                    raise Exception(f"Failed to get file metadata: {response.status}")

                return await response.json()

    async def list_files(
        self, page_size: int = 10, page_token: Optional[str] = None
    ) -> Dict[str, Any]:
        """List uploaded files, one page at a time.

        Args:
            page_size: Number of files to return per page
            page_token: Token for pagination; only sent when truthy

        Returns:
            The JSON body of the response.

        Raises:
            Exception: If the response status is not 200.
        """
        query = {"key": self._api_key, "pageSize": page_size}
        if page_token:
            query["pageToken"] = page_token

        async with aiohttp.ClientSession() as session:
            async with session.get(self._base_url, params=query) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error listing files: {error_text}")
                    raise Exception(f"Failed to list files: {response.status}")

                return await response.json()

    async def delete_file(self, name: str) -> bool:
        """Delete a file.

        Args:
            name: File name (or full path); only the segment after the last
                "/" is used.

        Returns:
            True if deleted successfully

        Raises:
            Exception: If the response status is not 200.
        """
        # Keep only the last path segment; a name without "/" is unchanged.
        short_name = name.rsplit("/", 1)[-1]
        url = f"{self._base_url}/{short_name}?key={self._api_key}"

        async with aiohttp.ClientSession() as session:
            async with session.delete(url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Error deleting file: {error_text}")
                    raise Exception(f"Failed to delete file: {response.status}")

                return True
|
||||
1582
src/pipecat/services/google/gemini_live/llm.py
Normal file
1582
src/pipecat/services/google/gemini_live/llm.py
Normal file
File diff suppressed because it is too large
Load Diff
184
src/pipecat/services/google/gemini_live/llm_vertex.py
Normal file
184
src/pipecat/services/google/gemini_live/llm_vertex.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Service for accessing Gemini Live via Google Vertex AI.
|
||||
|
||||
This module provides integration with Google's Gemini Live model via
|
||||
Vertex AI, supporting both text and audio modalities with voice transcription,
|
||||
streaming responses, and tool usage.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.services.google.gemini_live.llm import (
|
||||
GeminiLiveLLMService,
|
||||
HttpOptions,
|
||||
InputParams,
|
||||
)
|
||||
|
||||
try:
|
||||
from google.auth import default
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.auth.transport.requests import Request
|
||||
from google.genai import Client
|
||||
from google.oauth2 import service_account
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google Vertex AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GeminiLiveVertexLLMService(GeminiLiveLLMService):
    """Provides access to Google's Gemini Live model via Vertex AI.

    This service enables real-time conversations with Gemini, supporting both
    text and audio modalities. It handles voice transcription, streaming audio
    responses, and tool usage.

    Authentication uses Google service account credentials (or Application
    Default Credentials) rather than the API key used by the parent class.
    """

    def __init__(
        self,
        *,
        credentials: Optional[str] = None,
        credentials_path: Optional[str] = None,
        location: str,
        project_id: str,
        model: str = "google/gemini-2.0-flash-live-preview-04-09",
        voice_id: str = "Charon",
        start_audio_paused: bool = False,
        start_video_paused: bool = False,
        system_instruction: Optional[str] = None,
        tools: Optional[Union[List[dict], ToolsSchema]] = None,
        params: Optional[InputParams] = None,
        inference_on_context_initialization: bool = True,
        file_api_base_url: str = "https://generativelanguage.googleapis.com/v1beta/files",
        http_options: Optional[HttpOptions] = None,
        **kwargs,
    ):
        """Initialize the service for accessing Gemini Live via Google Vertex AI.

        Args:
            credentials: JSON string of service account credentials.
            credentials_path: Path to the service account JSON file.
            location: GCP region for Vertex AI endpoint (e.g., "us-east4").
            project_id: Google Cloud project ID.
            model: Model identifier to use. Defaults to
                "google/gemini-2.0-flash-live-preview-04-09".
            voice_id: TTS voice identifier. Defaults to "Charon".
            start_audio_paused: Whether to start with audio input paused. Defaults to False.
            start_video_paused: Whether to start with video input paused. Defaults to False.
            system_instruction: System prompt for the model. Defaults to None.
            tools: Tools/functions available to the model. Defaults to None.
            params: Configuration parameters for the model. Defaults to None.
            inference_on_context_initialization: Whether to generate a response when context
                is first set. Defaults to True.
            file_api_base_url: Base URL for the Gemini File API. Defaults to the official endpoint.
            http_options: HTTP options for the client.
            **kwargs: Additional arguments passed to parent GeminiLiveLLMService.

        Raises:
            ValueError: If ``api_key`` is passed, or if no valid credentials
                could be obtained.
        """
        # Check if user incorrectly passed api_key, which is used by parent
        # class but not here.
        if "api_key" in kwargs:
            logger.error(
                "GeminiLiveVertexLLMService does not accept 'api_key' parameter. "
                "Use 'credentials' or 'credentials_path' instead for Vertex AI authentication."
            )
            raise ValueError(
                "Invalid parameter 'api_key'. Use 'credentials' or 'credentials_path' for Vertex AI authentication."
            )

        # These need to be set before calling super().__init__() because
        # super().__init__() invokes create_client(), which needs these.
        self._credentials = self._get_credentials(credentials, credentials_path)
        self._project_id = project_id
        self._location = location

        super().__init__(
            # api_key is required by parent class, but actually not used with
            # Vertex: create_client() below authenticates with credentials.
            api_key="dummy",
            model=model,
            voice_id=voice_id,
            start_audio_paused=start_audio_paused,
            start_video_paused=start_video_paused,
            system_instruction=system_instruction,
            tools=tools,
            params=params,
            inference_on_context_initialization=inference_on_context_initialization,
            file_api_base_url=file_api_base_url,
            http_options=http_options,
            **kwargs,
        )

    def create_client(self):
        """Create the Gemini client instance configured for Vertex AI."""
        self._client = Client(
            vertexai=True,
            credentials=self._credentials,
            project=self._project_id,
            location=self._location,
        )

    @property
    def file_api(self):
        """Gemini File API is not supported with Vertex AI.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "When using Vertex AI, the recommended approach is to use Google Cloud Storage for file handling. The Gemini File API is not directly supported in this context."
        )

    @staticmethod
    def _get_credentials(credentials: Optional[str], credentials_path: Optional[str]):
        """Load and refresh Google credentials for Vertex AI.

        Supports multiple authentication methods, tried in this order:

        1. Direct JSON credentials string
        2. Path to service account JSON file
        3. Default application credentials (ADC)

        Args:
            credentials: JSON string of service account credentials.
            credentials_path: Path to the service account JSON file.

        Returns:
            The refreshed google-auth credentials object (not a token string);
            it is passed as ``credentials`` to the Gemini client.

        Raises:
            ValueError: If no valid credentials are provided or found.
        """
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        creds = None

        if credentials:
            # Parse and load credentials from JSON string
            creds = service_account.Credentials.from_service_account_info(
                json.loads(credentials),
                scopes=scopes,
            )
        elif credentials_path:
            # Load credentials from JSON file
            creds = service_account.Credentials.from_service_account_file(
                credentials_path,
                scopes=scopes,
            )
        else:
            try:
                # default() also returns a project id, which is not needed here.
                creds, _ = default(scopes=scopes)
            except GoogleAuthError as e:
                # Surface the underlying reason instead of silently dropping it.
                logger.error(f"Unable to load default application credentials: {e}")
                raise ValueError("No valid credentials provided.") from e

        if not creds:
            raise ValueError("No valid credentials provided.")

        # Ensure token is up-to-date, lifetime is 1 hour.
        # NOTE(review): this looks like a synchronous request made during
        # construction — confirm that is acceptable for callers.
        creds.refresh(Request())

        return creds
|
||||
@@ -94,9 +94,9 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
prompt_tokens=chunk.usage.prompt_tokens or 0,
|
||||
completion_tokens=chunk.usage.completion_tokens or 0,
|
||||
total_tokens=chunk.usage.total_tokens or 0,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
|
||||
@@ -53,12 +53,44 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
|
||||
Parameters:
|
||||
location: GCP region for Vertex AI endpoint (e.g., "us-east4").
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Use `location` as a direct argument to
|
||||
`GoogleVertexLLMService.__init__()` instead.
|
||||
|
||||
project_id: Google Cloud project ID.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Use `project_id` as a direct argument to
|
||||
`GoogleVertexLLMService.__init__()` instead.
|
||||
"""
|
||||
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
|
||||
location: str = "us-east4"
|
||||
project_id: str
|
||||
location: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
    """Initializes the InputParams.

    Emits a DeprecationWarning for each of the deprecated ``location`` and
    ``project_id`` fields that is supplied with a non-None value, then
    forwards all keyword arguments to the parent initializer.

    Args:
        **kwargs: Field values for the params model.
    """
    import warnings

    with warnings.catch_warnings():
        # Make sure the deprecation notices are shown every time.
        warnings.simplefilter("always")
        for deprecated_field in ("location", "project_id"):
            if kwargs.get(deprecated_field) is None:
                continue
            warnings.warn(
                f"GoogleVertexLLMService.InputParams.{deprecated_field} is deprecated. "
                f"Please provide '{deprecated_field}' as a direct argument to GoogleVertexLLMService.__init__() instead.",
                DeprecationWarning,
                stacklevel=2,
            )
    super().__init__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -66,7 +98,8 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
model: str = "google/gemini-2.0-flash-001",
|
||||
params: Optional[InputParams] = None,
|
||||
location: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the VertexLLMService.
|
||||
@@ -75,33 +108,60 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
credentials: JSON string of service account credentials.
|
||||
credentials_path: Path to the service account JSON file.
|
||||
model: Model identifier (e.g., "google/gemini-2.0-flash-001").
|
||||
params: Vertex AI input parameters including location and project.
|
||||
location: GCP region for Vertex AI endpoint (e.g., "us-east4").
|
||||
project_id: Google Cloud project ID.
|
||||
**kwargs: Additional arguments passed to OpenAILLMService.
|
||||
"""
|
||||
params = params or OpenAILLMService.InputParams()
|
||||
base_url = self._get_base_url(params)
|
||||
# Handle deprecated InputParams fields
|
||||
if "params" in kwargs and isinstance(kwargs["params"], GoogleVertexLLMService.InputParams):
|
||||
params = kwargs["params"]
|
||||
# Extract location and project_id from params if not provided
|
||||
# directly, for backward compatibility
|
||||
if project_id is None:
|
||||
project_id = params.project_id
|
||||
if location is None:
|
||||
location = params.location
|
||||
# Convert to base InputParams
|
||||
params = OpenAILLMService.InputParams(
|
||||
**params.model_dump(exclude={"location", "project_id"}, exclude_unset=True)
|
||||
)
|
||||
kwargs["params"] = params
|
||||
|
||||
# Validate project_id and location parameters
|
||||
# NOTE: once we remove Vertex-specific InputParams class, we can update
|
||||
# __init__() signature as follows:
|
||||
# - location: str = "us-east4",
|
||||
# - project_id: str,
|
||||
# But for now, we need them as-is to maintain proper backward
|
||||
# compatibility.
|
||||
if project_id is None:
|
||||
raise ValueError("project_id is required")
|
||||
if location is None:
|
||||
# If location is not provided, default to "us-east4".
|
||||
# Note: this is legacy behavior; ideally location would be
|
||||
# required.
|
||||
logger.warning("location is not provided. Defaulting to 'us-east4'.")
|
||||
location = "us-east4" # Default location if not provided
|
||||
|
||||
base_url = self._get_base_url(location, project_id)
|
||||
self._api_key = self._get_api_token(credentials, credentials_path)
|
||||
|
||||
super().__init__(
|
||||
api_key=self._api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_base_url(params: InputParams) -> str:
|
||||
def _get_base_url(location: str, project_id: str) -> str:
|
||||
"""Construct the base URL for Vertex AI API."""
|
||||
# Determine the correct API host based on location
|
||||
if params.location == "global":
|
||||
if location == "global":
|
||||
api_host = "aiplatform.googleapis.com"
|
||||
else:
|
||||
api_host = f"{params.location}-aiplatform.googleapis.com"
|
||||
return (
|
||||
f"https://{api_host}/v1/"
|
||||
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
|
||||
)
|
||||
api_host = f"{location}-aiplatform.googleapis.com"
|
||||
return f"https://{api_host}/v1/projects/{project_id}/locations/{location}/endpoints/openapi"
|
||||
|
||||
@staticmethod
|
||||
def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
|
||||
|
||||
@@ -730,6 +730,8 @@ class GoogleSTTService(STTService):
|
||||
self._request_queue = asyncio.Queue()
|
||||
self._streaming_task = self.create_task(self._stream_audio())
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Clean up streaming recognition resources."""
|
||||
if self._streaming_task:
|
||||
@@ -737,6 +739,8 @@ class GoogleSTTService(STTService):
|
||||
await self.cancel_task(self._streaming_task)
|
||||
self._streaming_task = None
|
||||
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _request_generator(self):
|
||||
"""Generates requests for the streaming recognize method."""
|
||||
recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
|
||||
|
||||
@@ -42,7 +42,7 @@ class HumeTTSService(TTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
|
||||
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
|
||||
using the Python SDK and emits `TTSAudioRawFrame`s suitable for Pipecat transports.
|
||||
using the Python SDK and emits ``TTSAudioRawFrame`` frames suitable for Pipecat transports.
|
||||
|
||||
Supported features:
|
||||
|
||||
@@ -78,7 +78,7 @@ class HumeTTSService(TTSService):
|
||||
|
||||
Args:
|
||||
api_key: Hume API key. If omitted, reads the ``HUME_API_KEY`` environment variable.
|
||||
voice_id: ID of the voice to use (ID-only; names are not supported here).
|
||||
voice_id: ID of the voice to use. Only voice IDs are supported; voice names are not.
|
||||
params: Optional synthesis controls (acting instructions, speed, trailing silence).
|
||||
sample_rate: Output sample rate for emitted PCM frames. Defaults to 48_000 (Hume).
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
|
||||
@@ -222,6 +222,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# Send initialization message
|
||||
await self._websocket.send(json.dumps(init_msg))
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -243,6 +244,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get the WebSocket connection if available."""
|
||||
|
||||
@@ -293,6 +293,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
headers = {"x-api-key": self._api_key}
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -311,6 +313,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process messages from Neuphonic WebSocket."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .image import *
|
||||
from .llm import *
|
||||
from .realtime import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0).
|
||||
max_tokens: Maximum tokens in response (deprecated, use max_completion_tokens).
|
||||
max_completion_tokens: Maximum completion tokens to generate.
|
||||
service_tier: Service tier to use (e.g., "auto", "flex", "priority").
|
||||
extra: Additional model-specific parameters.
|
||||
"""
|
||||
|
||||
@@ -83,6 +84,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
service_tier: Optional[str] = Field(default_factory=lambda: NOT_GIVEN)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
@@ -125,6 +127,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
"top_p": params.top_p,
|
||||
"max_tokens": params.max_tokens,
|
||||
"max_completion_tokens": params.max_completion_tokens,
|
||||
"service_tier": params.service_tier,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self._retry_timeout_secs = retry_timeout_secs
|
||||
@@ -236,6 +239,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
"top_p": self._settings["top_p"],
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"max_completion_tokens": self._settings["max_completion_tokens"],
|
||||
"service_tier": self._settings["service_tier"],
|
||||
}
|
||||
|
||||
# Messages, tools, tool_choice
|
||||
|
||||
0
src/pipecat/services/openai/realtime/__init__.py
Normal file
0
src/pipecat/services/openai/realtime/__init__.py
Normal file
272
src/pipecat/services/openai/realtime/context.py
Normal file
272
src/pipecat/services/openai/realtime/context.py
Normal file
@@ -0,0 +1,272 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime LLM context and aggregator implementations."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
from . import events
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
    """OpenAI Realtime LLM context with session management and message conversion.

    Extends the standard OpenAI LLM context to support real-time session properties,
    instruction management, and conversion between standard message formats and
    realtime conversation items.
    """

    def __init__(self, messages=None, tools=None, **kwargs):
        """Initialize the OpenAIRealtimeLLMContext.

        Args:
            messages: Initial conversation messages. Defaults to None.
            tools: Available function tools. Defaults to None.
            **kwargs: Additional arguments passed to parent OpenAILLMContext.
        """
        super().__init__(messages=messages, tools=tools, **kwargs)
        self.__setup_local()

    def __setup_local(self):
        """Set the realtime-specific attributes to their initial values."""
        self.llm_needs_settings_update = True
        self.llm_needs_initial_messages = True
        self._session_instructions = ""

    @staticmethod
    def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
        """Upgrade a standard OpenAI LLM context to a realtime context.

        The object is converted in place by reassigning its class; objects that
        are already realtime contexts are returned untouched.

        Args:
            obj: The OpenAILLMContext instance to upgrade.

        Returns:
            The upgraded OpenAIRealtimeLLMContext instance.
        """
        if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
            obj.__class__ = OpenAIRealtimeLLMContext
            obj.__setup_local()
        return obj

    # todo
    #   - finish implementing all frames

    def from_standard_message(self, message):
        """Convert a standard message format to a realtime conversation item.

        Handles "user" messages (string content, or a list of parts of which
        only "text" parts are used) and "assistant" messages carrying tool
        calls (only the first tool call is converted). Any other message is
        logged as an error.

        Args:
            message: The standard message dictionary to convert.

        Returns:
            A ConversationItem instance for the realtime API, or None if the
            message type is not handled.
        """
        role = message.get("role")

        if role == "user":
            content = message.get("content")
            if isinstance(content, list):
                texts = []
                for part in content:
                    if part.get("type") == "text":
                        text = part.get("text")
                        # A text part without a "text" value used to raise a
                        # TypeError during concatenation; skip it instead.
                        if text is not None:
                            texts.append(text)
                    else:
                        logger.error(
                            f"Unhandled content type in context message: {part.get('type')} - {message}"
                        )
                # Each text part is prefixed with a single space, preserving
                # the existing output format.
                content = "".join(f" {text}" for text in texts)
            return events.ConversationItem(
                role="user",
                type="message",
                content=[events.ItemContent(type="input_text", text=content)],
            )

        if role == "assistant" and message.get("tool_calls"):
            tc = message.get("tool_calls")[0]
            return events.ConversationItem(
                type="function_call",
                call_id=tc["id"],
                name=tc["function"]["name"],
                arguments=tc["function"]["arguments"],
            )

        logger.error(f"Unhandled message type in from_standard_message: {message}")
        return None

    def get_messages_for_initializing_history(self):
        """Get conversation items for initializing the realtime session history.

        Converts the context's messages to a format suitable for the realtime API,
        handling system instructions and conversation history packaging. A leading
        "system" message is removed from the history and stored as the session
        instructions.

        Returns:
            List of conversation items for session initialization.
        """
        # We can't load a long conversation history into the openai realtime api yet. (The API/model
        # forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
        # our general strategy until this is fixed is just to put everything into a first "user"
        # message as a single input.
        if not self.messages:
            return []

        messages = copy.deepcopy(self.messages)

        # If we have a "system" message as our first message, let's pull that out into session
        # "instructions"
        if messages[0].get("role") == "system":
            self.llm_needs_settings_update = True
            system = messages.pop(0)
            content = system.get("content")
            if isinstance(content, str):
                self._session_instructions = content
            elif isinstance(content, list) and content:
                # Only the first part is used; an empty list leaves the current
                # instructions untouched instead of raising IndexError.
                self._session_instructions = content[0].get("text")
            if not messages:
                return []

        # If we have just a single "user" item, we can just send it normally
        if len(messages) == 1 and messages[0].get("role") == "user":
            return [self.from_standard_message(messages[0])]

        # Otherwise, let's pack everything into a single "user" message with a bit of
        # explanation for the LLM. The prompts are written as explicit literals so
        # their text does not depend on how this source file is indented.
        intro_text = (
            "\n"
            "This is a previously saved conversation. Please treat this conversation history as a\n"
            "starting point for the current conversation."
        )

        trailing_text = (
            "\n"
            "This is the end of the previously saved conversation. Please continue the conversation\n"
            "from here. If the last message is a user instruction or question, act on that instruction\n"
            "or answer the question. If the last message is an assistant response, simple say that you\n"
            "are ready to continue the conversation."
        )

        packed_text = "\n\n".join([intro_text, json.dumps(messages, indent=2), trailing_text])

        return [
            {
                "role": "user",
                "type": "message",
                "content": [
                    {
                        "type": "input_text",
                        "text": packed_text,
                    }
                ],
            }
        ]

    def add_user_content_item_as_message(self, item):
        """Add a user content item as a standard message to the context.

        The transcript of the item's first content entry becomes the text of
        the new "user" message.

        Args:
            item: The conversation item to add as a user message.
        """
        message = {
            "role": "user",
            "content": [{"type": "text", "text": item.content[0].transcript}],
        }
        self.add_message(message)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
    """User-side context aggregator for the OpenAI Realtime API.

    Lets the parent aggregator process every frame, then forwards the
    realtime-specific updates (message updates and tool settings) that the
    parent does not push on its own.

    Args:
        context: The OpenAI realtime LLM context.
        **kwargs: Additional arguments passed to parent aggregator.
    """

    async def process_frame(
        self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
    ):
        """Process an incoming frame, adding realtime-specific forwarding.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        # The parent does not push LLMMessagesUpdateFrame, so in a typical
        # pipeline messages are only handled by the user context aggregator,
        # which is generally what we want. The new messages still have to be
        # sent over the websocket so the openai realtime API has them in its
        # context; a dedicated frame carries them there.
        if isinstance(frame, LLMMessagesUpdateFrame):
            messages_update = RealtimeMessagesUpdateFrame(context=self._context)
            await self.push_frame(messages_update)

        # LLMSetToolsFrame is not pushed by the parent either; pass it along.
        if isinstance(frame, LLMSetToolsFrame):
            await self.push_frame(frame, direction)

    async def push_aggregation(self):
        """Push user input aggregation (intentionally a no-op).

        All user input coming into the pipeline is currently ignored, as
        realtime audio input is handled directly by the service.
        """
        # todo: think about whether/how to fix this to allow for text input from
        #       upstream (transport/transcription, or other sources)
        return None
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
    """Assistant-side context aggregator for the OpenAI Realtime API.

    Handles assistant output frames from the realtime service, dropping the
    text frames that would duplicate context content and relaying function
    call results back to the service.

    Args:
        context: The OpenAI realtime LLM context.
        **kwargs: Additional arguments passed to parent aggregator.
    """

    # The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
    # but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
    # override process_frame so that LLMTextFrames are skipped and only the
    # TTSTextFrames are processed. This ensures that the context gets only one set
    # of messages. OpenAIRealtimeLLMService also pushes TranscriptionFrames and
    # InterimTranscriptionFrames, which are TextFrames too, so those are skipped
    # as well.
    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process assistant frames, filtering out duplicate text content.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        skipped_types = (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)
        if isinstance(frame, skipped_types):
            return
        await super().process_frame(frame, direction)

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Handle function call result and notify the realtime service.

        Args:
            frame: The function call result frame to handle.
        """
        await super().handle_function_call_result(frame)

        # The standard function callback code path pushes the FunctionCallResultFrame
        # from the llm itself, so there was no chance to add the result to the openai
        # realtime api context. A special frame is pushed upstream to do that.
        realtime_result = RealtimeFunctionCallResultFrame(result_frame=frame)
        await self.push_frame(realtime_result, FrameDirection.UPSTREAM)
|
||||
1106
src/pipecat/services/openai/realtime/events.py
Normal file
1106
src/pipecat/services/openai/realtime/events.py
Normal file
File diff suppressed because it is too large
Load Diff
37
src/pipecat/services/openai/realtime/frames.py
Normal file
37
src/pipecat/services/openai/realtime/frames.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frame types for OpenAI Realtime API integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.openai.realtime.context import OpenAIRealtimeLLMContext
|
||||
|
||||
|
||||
@dataclass
class RealtimeMessagesUpdateFrame(DataFrame):
    """Data frame signalling that the realtime context messages were updated.

    Parameters:
        context: The updated OpenAI realtime LLM context.
    """

    # String annotation: the context class is imported only under TYPE_CHECKING.
    context: "OpenAIRealtimeLLMContext"
|
||||
|
||||
|
||||
@dataclass
class RealtimeFunctionCallResultFrame(DataFrame):
    """Data frame carrying a function call result for the realtime service.

    Parameters:
        result_frame: The function call result frame to send to the realtime API.
    """

    # The wrapped standard result frame.
    result_frame: FunctionCallResultFrame
|
||||
@@ -14,6 +14,7 @@ from typing import AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
@@ -55,6 +56,17 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for OpenAI TTS configuration.
|
||||
|
||||
Parameters:
|
||||
instructions: Instructions to guide voice synthesis behavior.
|
||||
speed: Voice speed control (0.25 to 4.0, default 1.0).
|
||||
"""
|
||||
|
||||
instructions: Optional[str] = None
|
||||
speed: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -65,6 +77,7 @@ class OpenAITTSService(TTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
instructions: Optional[str] = None,
|
||||
speed: Optional[float] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI TTS service.
|
||||
@@ -77,7 +90,11 @@ class OpenAITTSService(TTSService):
|
||||
sample_rate: Output audio sample rate in Hz. If None, uses OpenAI's default 24kHz.
|
||||
instructions: Optional instructions to guide voice synthesis behavior.
|
||||
speed: Voice speed control (0.25 to 4.0, default 1.0).
|
||||
params: Optional synthesis controls (acting instructions, speed, ...).
|
||||
**kwargs: Additional keyword arguments passed to TTSService.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
The `instructions` and `speed` parameters are deprecated, use `InputParams` instead.
|
||||
"""
|
||||
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
@@ -86,12 +103,26 @@ class OpenAITTSService(TTSService):
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._speed = speed
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice)
|
||||
self._instructions = instructions
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
if instructions or speed:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The `instructions` and `speed` parameters are deprecated, use `InputParams` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._settings = {
|
||||
"instructions": params.instructions if params else instructions,
|
||||
"speed": params.speed if params else speed,
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
@@ -144,11 +175,11 @@ class OpenAITTSService(TTSService):
|
||||
"response_format": "pcm",
|
||||
}
|
||||
|
||||
if self._instructions:
|
||||
create_params["instructions"] = self._instructions
|
||||
if self._settings["instructions"]:
|
||||
create_params["instructions"] = self._settings["instructions"]
|
||||
|
||||
if self._speed:
|
||||
create_params["speed"] = self._speed
|
||||
if self._settings["speed"]:
|
||||
create_params["speed"] = self._settings["speed"]
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
**create_params
|
||||
|
||||
@@ -1,9 +1,27 @@
|
||||
from .azure import AzureRealtimeLLMService
|
||||
from .events import (
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from .openai import OpenAIRealtimeLLMService
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai_realtime are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.openai.realtime instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -1,67 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Azure OpenAI Realtime LLM service implementation."""
|
||||
|
||||
from loguru import logger
|
||||
import warnings
|
||||
|
||||
from .openai import OpenAIRealtimeLLMService
|
||||
from pipecat.services.azure.realtime.llm import *
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai_realtime.azure are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.azure.realtime.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
|
||||
"""Azure OpenAI Realtime LLM service with Azure-specific authentication.
|
||||
|
||||
Extends the OpenAI Realtime service to work with Azure OpenAI endpoints,
|
||||
using Azure's authentication headers and endpoint format. Provides the same
|
||||
real-time audio and text communication capabilities as the base OpenAI service.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Azure Realtime LLM service.
|
||||
|
||||
Args:
|
||||
api_key: The API key for the Azure OpenAI service.
|
||||
base_url: The full Azure WebSocket endpoint URL including api-version and deployment.
|
||||
Example: "wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment"
|
||||
**kwargs: Additional arguments passed to parent OpenAIRealtimeLLMService.
|
||||
"""
|
||||
super().__init__(base_url=base_url, api_key=api_key, **kwargs)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
if self._websocket:
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
|
||||
logger.info(f"Connecting to {self.base_url}, api key: {self.api_key}")
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self.base_url,
|
||||
additional_headers={
|
||||
"api-key": self.api_key,
|
||||
},
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
@@ -1,272 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime LLM context and aggregator implementations."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.services.openai.realtime.context import *
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
from . import events
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
"""OpenAI Realtime LLM context with session management and message conversion.
|
||||
|
||||
Extends the standard OpenAI LLM context to support real-time session properties,
|
||||
instruction management, and conversion between standard message formats and
|
||||
realtime conversation items.
|
||||
"""
|
||||
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
"""Initialize the OpenAIRealtimeLLMContext.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages. Defaults to None.
|
||||
tools: Available function tools. Defaults to None.
|
||||
**kwargs: Additional arguments passed to parent OpenAILLMContext.
|
||||
"""
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self):
|
||||
self.llm_needs_settings_update = True
|
||||
self.llm_needs_initial_messages = True
|
||||
self._session_instructions = ""
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
"""Upgrade a standard OpenAI LLM context to a realtime context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAILLMContext instance to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded OpenAIRealtimeLLMContext instance.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
|
||||
obj.__class__ = OpenAIRealtimeLLMContext
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# todo
|
||||
# - finish implementing all frames
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert a standard message format to a realtime conversation item.
|
||||
|
||||
Args:
|
||||
message: The standard message dictionary to convert.
|
||||
|
||||
Returns:
|
||||
A ConversationItem instance for the realtime API.
|
||||
"""
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if isinstance(message.get("content"), list):
|
||||
content = ""
|
||||
for c in message.get("content"):
|
||||
if c.get("type") == "text":
|
||||
content += " " + c.get("text")
|
||||
else:
|
||||
logger.error(
|
||||
f"Unhandled content type in context message: {c.get('type')} - {message}"
|
||||
)
|
||||
return events.ConversationItem(
|
||||
role="user",
|
||||
type="message",
|
||||
content=[events.ItemContent(type="input_text", text=content)],
|
||||
)
|
||||
if message.get("role") == "assistant" and message.get("tool_calls"):
|
||||
tc = message.get("tool_calls")[0]
|
||||
return events.ConversationItem(
|
||||
type="function_call",
|
||||
call_id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
)
|
||||
logger.error(f"Unhandled message type in from_standard_message: {message}")
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
"""Get conversation items for initializing the realtime session history.
|
||||
|
||||
Converts the context's messages to a format suitable for the realtime API,
|
||||
handling system instructions and conversation history packaging.
|
||||
|
||||
Returns:
|
||||
List of conversation items for session initialization.
|
||||
"""
|
||||
# We can't load a long conversation history into the openai realtime api yet. (The API/model
|
||||
# forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
|
||||
# our general strategy until this is fixed is just to put everything into a first "user"
|
||||
# message as a single input.
|
||||
if not self.messages:
|
||||
return []
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into session
|
||||
# "instructions"
|
||||
if messages[0].get("role") == "system":
|
||||
self.llm_needs_settings_update = True
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
self._session_instructions = content
|
||||
elif isinstance(content, list):
|
||||
self._session_instructions = content[0].get("text")
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# If we have just a single "user" item, we can just send it normally
|
||||
if len(messages) == 1 and messages[0].get("role") == "user":
|
||||
return [self.from_standard_message(messages[0])]
|
||||
|
||||
# Otherwise, let's pack everything into a single "user" message with a bit of
|
||||
# explanation for the LLM
|
||||
intro_text = """
|
||||
This is a previously saved conversation. Please treat this conversation history as a
|
||||
starting point for the current conversation."""
|
||||
|
||||
trailing_text = """
|
||||
This is the end of the previously saved conversation. Please continue the conversation
|
||||
from here. If the last message is a user instruction or question, act on that instruction
|
||||
or answer the question. If the last message is an assistant response, simple say that you
|
||||
are ready to continue the conversation."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "\n\n".join(
|
||||
[intro_text, json.dumps(messages, indent=2), trailing_text]
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def add_user_content_item_as_message(self, item):
|
||||
"""Add a user content item as a standard message to the context.
|
||||
|
||||
Args:
|
||||
item: The conversation item to add as a user message.
|
||||
"""
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": item.content[0].transcript}],
|
||||
}
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""User context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles user input frames and generates appropriate context updates
|
||||
for the realtime conversation, including message updates and tool settings.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Process incoming frames and handle realtime-specific frame types.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
|
||||
# messages are only processed by the user context aggregator, which is generally what we want. But
|
||||
# we also need to send new messages over the websocket, so the openai realtime API has them
|
||||
# in its context.
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(RealtimeMessagesUpdateFrame(context=self._context))
|
||||
|
||||
# Parent also doesn't push the LLMSetToolsFrame.
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push user input aggregation.
|
||||
|
||||
Currently ignores all user input coming into the pipeline as realtime
|
||||
audio input is handled directly by the service.
|
||||
"""
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Assistant context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles assistant output frames from the realtime service, filtering
|
||||
out duplicate text frames and managing function call results.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override process_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are processed. This ensures that the context gets only one set of messages.
|
||||
# OpenAIRealtimeLLMService also pushes TranscriptionFrames and InterimTranscriptionFrames,
|
||||
# so we need to ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process assistant frames, filtering out duplicate text content.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result and notify the realtime service.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai_realtime.context are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.openai.realtime.context instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,37 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frame types for OpenAI Realtime API integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
from pipecat.services.openai.realtime.frames import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.openai_realtime.context import OpenAIRealtimeLLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeMessagesUpdateFrame(DataFrame):
|
||||
"""Frame indicating that the realtime context messages have been updated.
|
||||
|
||||
Parameters:
|
||||
context: The updated OpenAI realtime LLM context.
|
||||
"""
|
||||
|
||||
context: "OpenAIRealtimeLLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call results for the realtime service.
|
||||
|
||||
Parameters:
|
||||
result_frame: The function call result frame to send to the realtime API.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai_realtime.frames are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.openai.realtime.frames instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
@@ -99,16 +98,15 @@ class PiperTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||
# remove wav header if present
|
||||
if chunk.startswith(b"RIFF"):
|
||||
chunk = chunk[44:]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
response.content.iter_chunked(CHUNK_SIZE), strip_wav_header=True
|
||||
):
|
||||
await self.stop_ttfb_metrics()
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
|
||||
@@ -269,6 +269,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
raise ValueError("WebSocket URL is not a string")
|
||||
|
||||
self._websocket = await websocket_connect(self._websocket_url)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except ValueError as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -291,6 +293,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _get_websocket_url(self):
|
||||
"""Retrieve WebSocket URL from PlayHT API."""
|
||||
|
||||
@@ -255,6 +255,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
url = f"{self._url}?{params}"
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -272,6 +274,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
@@ -553,15 +556,13 @@ class RimeHttpTTSService(TTSService):
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||
if need_to_strip_wav_header and chunk.startswith(b"RIFF"):
|
||||
chunk = chunk[44:]
|
||||
need_to_strip_wav_header = False
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
response.content.iter_chunked(CHUNK_SIZE),
|
||||
strip_wav_header=need_to_strip_wav_header,
|
||||
):
|
||||
await self.stop_ttfb_metrics()
|
||||
yield frame
|
||||
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"Rime TTS error: {str(e)}")
|
||||
|
||||
@@ -583,7 +583,9 @@ class RivaSegmentedSTTService(SegmentedSTTService):
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(self, transcript: str, language: Optional[Language] = None):
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -76,17 +76,29 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
Example::
|
||||
|
||||
tts = SarvamTTSService(
|
||||
tts = SarvamHttpTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="anushka",
|
||||
model="bulbul:v2",
|
||||
aiohttp_session=session,
|
||||
params=SarvamTTSService.InputParams(
|
||||
params=SarvamHttpTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
pitch=0.1,
|
||||
pace=1.2
|
||||
)
|
||||
)
|
||||
|
||||
# For bulbul v3 beta with any speaker:
|
||||
tts_v3 = SarvamHttpTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="speaker_name",
|
||||
model="bulbul:v3",
|
||||
aiohttp_session=session,
|
||||
params=SarvamHttpTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
temperature=0.8
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -105,6 +117,14 @@ class SarvamHttpTTSService(TTSService):
|
||||
pace: Optional[float] = Field(default=1.0, ge=0.3, le=3.0)
|
||||
loudness: Optional[float] = Field(default=1.0, ge=0.1, le=3.0)
|
||||
enable_preprocessing: Optional[bool] = False
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.6,
|
||||
ge=0.01,
|
||||
le=1.0,
|
||||
description="Controls the randomness of the output for bulbul v3 beta. "
|
||||
"Lower values make the output more focused and deterministic, while "
|
||||
"higher values make it more random. Range: 0.01 to 1.0. Default: 0.6.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -124,7 +144,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
api_key: Sarvam AI API subscription key.
|
||||
aiohttp_session: Shared aiohttp session for making requests.
|
||||
voice_id: Speaker voice ID (e.g., "anushka", "meera"). Defaults to "anushka".
|
||||
model: TTS model to use ("bulbul:v1" or "bulbul:v2"). Defaults to "bulbul:v2".
|
||||
model: TTS model to use ("bulbul:v2" or "bulbul:v3-beta" or "bulbul:v3"). Defaults to "bulbul:v2".
|
||||
base_url: Sarvam AI API base URL. Defaults to "https://api.sarvam.ai".
|
||||
sample_rate: Audio sample rate in Hz (8000, 16000, 22050, 24000). If None, uses default.
|
||||
params: Additional voice and preprocessing parameters. If None, uses defaults.
|
||||
@@ -138,16 +158,32 @@ class SarvamHttpTTSService(TTSService):
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
|
||||
# Build base settings common to all models
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if model in ("bulbul:v3-beta", "bulbul:v3"):
|
||||
self._settings.update(
|
||||
{
|
||||
"temperature": getattr(params, "temperature", 0.6),
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._settings.update(
|
||||
{
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
@@ -275,6 +311,18 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
pace=1.2
|
||||
)
|
||||
)
|
||||
|
||||
# For bulbul v3 beta with any speaker and temperature:
|
||||
# Note: pace and loudness are not supported for bulbul v3 and bulbul v3 beta
|
||||
tts_v3 = SarvamTTSService(
|
||||
api_key="your-api-key",
|
||||
voice_id="speaker_name",
|
||||
model="bulbul:v3",
|
||||
params=SarvamTTSService.InputParams(
|
||||
language=Language.HI,
|
||||
temperature=0.8
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -310,6 +358,14 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
output_audio_codec: Optional[str] = "linear16"
|
||||
output_audio_bitrate: Optional[str] = "128k"
|
||||
language: Optional[Language] = Language.EN
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.6,
|
||||
ge=0.01,
|
||||
le=1.0,
|
||||
description="Controls the randomness of the output for bulbul v3 beta. "
|
||||
"Lower values make the output more focused and deterministic, while "
|
||||
"higher values make it more random. Range: 0.01 to 1.0. Default: 0.6.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -329,6 +385,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
Args:
|
||||
api_key: Sarvam API key for authenticating TTS requests.
|
||||
model: Identifier of the Sarvam speech model (default "bulbul:v2").
|
||||
Supports "bulbul:v2", "bulbul:v3-beta" and "bulbul:v3".
|
||||
voice_id: Voice identifier for synthesis (default "anushka").
|
||||
url: WebSocket URL for connecting to the TTS backend (default production URL).
|
||||
aiohttp_session: Optional shared aiohttp session. To maintain backward compatibility.
|
||||
@@ -371,15 +428,12 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
self._api_key = api_key
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
# Configuration parameters
|
||||
# Build base settings common to all models
|
||||
self._settings = {
|
||||
"target_language_code": (
|
||||
self.language_to_service_language(params.language) if params.language else "en-IN"
|
||||
),
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"speaker": voice_id,
|
||||
"loudness": params.loudness,
|
||||
"speech_sample_rate": 0,
|
||||
"enable_preprocessing": params.enable_preprocessing,
|
||||
"min_buffer_size": params.min_buffer_size,
|
||||
@@ -387,6 +441,24 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
"output_audio_codec": params.output_audio_codec,
|
||||
"output_audio_bitrate": params.output_audio_bitrate,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if model in ("bulbul:v3-beta", "bulbul:v3"):
|
||||
self._settings.update(
|
||||
{
|
||||
"temperature": getattr(params, "temperature", 0.6),
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._settings.update(
|
||||
{
|
||||
"pitch": params.pitch,
|
||||
"pace": params.pace,
|
||||
"loudness": params.loudness,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
self._started = False
|
||||
|
||||
self._receive_task = None
|
||||
@@ -525,6 +597,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
logger.debug("Connected to Sarvam TTS Websocket")
|
||||
await self._send_config()
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -556,6 +629,10 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
|
||||
@@ -577,6 +577,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
),
|
||||
)
|
||||
logger.debug(f"{self} Connected to Speechmatics STT service")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error connecting to Speechmatics: {e}")
|
||||
self._client = None
|
||||
@@ -595,6 +596,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
logger.error(f"{self} Error closing Speechmatics client: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _process_config(self) -> None:
|
||||
"""Create a formatted STT transcription config.
|
||||
@@ -618,7 +620,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
transcription_config.additional_vocab = [
|
||||
{
|
||||
"content": e.content,
|
||||
"sounds_like": e.sounds_like,
|
||||
**({"sounds_like": e.sounds_like} if e.sounds_like else {}),
|
||||
}
|
||||
for e in self._params.additional_vocab
|
||||
]
|
||||
|
||||
@@ -35,6 +35,25 @@ class STTService(AIService):
|
||||
Provides common functionality for STT services including audio passthrough,
|
||||
muting, settings management, and audio processing. Subclasses must implement
|
||||
the run_stt method to provide actual speech recognition.
|
||||
|
||||
Event handlers:
|
||||
on_connected: Called when connected to the STT service.
|
||||
on_disconnected: Called when disconnected from the STT service.
|
||||
on_connection_error: Called when an error occurs on the connection to the STT service.
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_connected")
|
||||
async def on_connected(stt: STTService):
|
||||
logger.debug(f"STT connected")
|
||||
|
||||
@stt.event_handler("on_disconnected")
|
||||
async def on_disconnected(stt: STTService):
|
||||
logger.debug(f"STT disconnected")
|
||||
|
||||
@stt.event_handler("on_connection_error")
|
||||
async def on_connection_error(stt: STTService, error: str):
|
||||
logger.error(f"STT connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -62,6 +81,10 @@ class STTService(AIService):
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
@property
|
||||
def is_muted(self) -> bool:
|
||||
"""Check if the STT service is currently muted.
|
||||
@@ -292,15 +315,6 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
|
||||
Combines STT functionality with websocket connectivity, providing automatic
|
||||
error handling and reconnection capabilities.
|
||||
|
||||
Event handlers:
|
||||
on_connection_error: Called when a websocket connection error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_connection_error")
|
||||
async def on_connection_error(stt: STTService, error: str):
|
||||
logger.error(f"STT connection error: {error}")
|
||||
"""
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
@@ -312,7 +326,6 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
"""
|
||||
STTService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user