Compare commits
220 Commits
hush/conte
...
v0.0.98
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9fef78070 | ||
|
|
92970c7873 | ||
|
|
491d298c10 | ||
|
|
c46a20328d | ||
|
|
7e4dbf42e8 | ||
|
|
159e403ae4 | ||
|
|
d3d50ac580 | ||
|
|
e03e5f3a59 | ||
|
|
65e4719cec | ||
|
|
d07b37b288 | ||
|
|
ca97d9dc4b | ||
|
|
4c20483a7e | ||
|
|
6d84f36d05 | ||
|
|
0b6e8f5bca | ||
|
|
cdd6f5aa6a | ||
|
|
f1a0d547ce | ||
|
|
b1b7fc6357 | ||
|
|
b3403e884d | ||
|
|
16e304016d | ||
|
|
21a55f6aae | ||
|
|
310df33de6 | ||
|
|
c8a86059fb | ||
|
|
c537d7bafb | ||
|
|
1fce68cef1 | ||
|
|
ecd9ec4ad2 | ||
|
|
db983cb693 | ||
|
|
5b30f1b1ef | ||
|
|
5f7dbfe775 | ||
|
|
2bb6ba59fc | ||
|
|
ac7b06faba | ||
|
|
afa7573834 | ||
|
|
f2eb9eeb56 | ||
|
|
9e49e09360 | ||
|
|
b5221cd2c1 | ||
|
|
796f3aeff3 | ||
|
|
de94790b94 | ||
|
|
bd3bf9a00e | ||
|
|
92f934031d | ||
|
|
11b92d89d0 | ||
|
|
0d1a122582 | ||
|
|
24b5efb9d8 | ||
|
|
eeb3b85e39 | ||
|
|
8255770b6c | ||
|
|
d3f918eb58 | ||
|
|
36c6549426 | ||
|
|
88d909d468 | ||
|
|
21e346abe2 | ||
|
|
70a80847a7 | ||
|
|
27647fc067 | ||
|
|
85fe6d4c34 | ||
|
|
4cd971e4bd | ||
|
|
54926f390d | ||
|
|
50362ca37e | ||
|
|
a14c911fb2 | ||
|
|
a5e42337a4 | ||
|
|
4f848e9631 | ||
|
|
93df7044fa | ||
|
|
e604e9b490 | ||
|
|
2e4fa3f8db | ||
|
|
5f6448a8a4 | ||
|
|
6cda357ce8 | ||
|
|
7e87f61d17 | ||
|
|
ccdf83800b | ||
|
|
4b81be7acf | ||
|
|
abc2ad8cbc | ||
|
|
64471d65f8 | ||
|
|
3c4991a41f | ||
|
|
71d6516a14 | ||
|
|
22288648e6 | ||
|
|
a6ee040d82 | ||
|
|
87fc860cd5 | ||
|
|
b25ad21941 | ||
|
|
debcea3baa | ||
|
|
c2abe42a64 | ||
|
|
56dee06a29 | ||
|
|
60cc14cafd | ||
|
|
1e98094394 | ||
|
|
ccdd6cde52 | ||
|
|
12979293ad | ||
|
|
28248e9b00 | ||
|
|
0e88ad672e | ||
|
|
f41c3dcbc3 | ||
|
|
645e1802f8 | ||
|
|
6636da682c | ||
|
|
10a32c943f | ||
|
|
455579ffcc | ||
|
|
c37da6ab78 | ||
|
|
1892854516 | ||
|
|
735e597bf2 | ||
|
|
52980a69c5 | ||
|
|
ff2f1dac82 | ||
|
|
3cbfbb997e | ||
|
|
3e66cb50e0 | ||
|
|
b821dd2507 | ||
|
|
0c5bccd1f1 | ||
|
|
926514ca18 | ||
|
|
ca5e668f4a | ||
|
|
53de6c0b9a | ||
|
|
b22ac8292f | ||
|
|
83877ab1e6 | ||
|
|
2a6a0d83db | ||
|
|
6ca117a3c1 | ||
|
|
4fcb099fd7 | ||
|
|
c5ff5cc219 | ||
|
|
88289f578a | ||
|
|
229ff794d6 | ||
|
|
096db3eb6c | ||
|
|
cfd1cada8c | ||
|
|
ee435b6f1e | ||
|
|
d289b38ba7 | ||
|
|
b0f63c3785 | ||
|
|
1249ee3de3 | ||
|
|
b09d8bd595 | ||
|
|
540a48b1b6 | ||
|
|
aa0529ff82 | ||
|
|
7e92597c0e | ||
|
|
99f89351fa | ||
|
|
0b4d984be6 | ||
|
|
17203ba3e6 | ||
|
|
924831089c | ||
|
|
329b8ac426 | ||
|
|
61674d7758 | ||
|
|
b9990811b5 | ||
|
|
8ccc2cbf31 | ||
|
|
f4e33fc8dd | ||
|
|
5bfea84bd5 | ||
|
|
ef703e9d16 | ||
|
|
44aa11737b | ||
|
|
49f1f7d6a2 | ||
|
|
4ea51ff67c | ||
|
|
747bd4f737 | ||
|
|
15f5583fd2 | ||
|
|
c8c6f424cd | ||
|
|
0cdf0c4504 | ||
|
|
217f03b9cc | ||
|
|
12093fcffc | ||
|
|
e5fb643cf5 | ||
|
|
4517475db7 | ||
|
|
92b6e8d66b | ||
|
|
3be1a7afaa | ||
|
|
15df3c06e8 | ||
|
|
f0af0a6b96 | ||
|
|
4cefe1357c | ||
|
|
4df0a9bf73 | ||
|
|
9ef139d020 | ||
|
|
9103d4ae05 | ||
|
|
bd63b6cefa | ||
|
|
4d03270bc3 | ||
|
|
0debcee761 | ||
|
|
6aee72c5b4 | ||
|
|
8d62cfb1b6 | ||
|
|
41214236ab | ||
|
|
b25963a63b | ||
|
|
8c6ef21d84 | ||
|
|
f729b1625b | ||
|
|
0ffaa09c95 | ||
|
|
f6e31b7e89 | ||
|
|
49b2b12e04 | ||
|
|
7ad3969690 | ||
|
|
af089a65ae | ||
|
|
48422dd442 | ||
|
|
fed6a8b669 | ||
|
|
82e0253a62 | ||
|
|
a7f26dca60 | ||
|
|
459ef27f3f | ||
|
|
464cfa5ccb | ||
|
|
9289881a80 | ||
|
|
34033cd454 | ||
|
|
47c21c9579 | ||
|
|
3b0bcf0b66 | ||
|
|
c4a8308027 | ||
|
|
e9f76dcaf2 | ||
|
|
21b2229b2b | ||
|
|
11aa9c9e68 | ||
|
|
9f4680e9bd | ||
|
|
04443a3820 | ||
|
|
1571cc58ac | ||
|
|
dea80cf946 | ||
|
|
91dec044c4 | ||
|
|
8cf4267d87 | ||
|
|
0ee7cab6c6 | ||
|
|
74c2039bfb | ||
|
|
66088837cd | ||
|
|
07ebf8534a | ||
|
|
fce4cfba15 | ||
|
|
af52833ca0 | ||
|
|
9fdf756375 | ||
|
|
283bbb385c | ||
|
|
8c6b2edb25 | ||
|
|
6ab30f9b87 | ||
|
|
3d93285bdf | ||
|
|
7261cd28f2 | ||
|
|
33eeb8ce44 | ||
|
|
ebda94ca98 | ||
|
|
40b17cff8f | ||
|
|
7ba0ebba11 | ||
|
|
b39087027c | ||
|
|
e65974c870 | ||
|
|
b1e5d68d97 | ||
|
|
39bca074d7 | ||
|
|
9dd882ecf8 | ||
|
|
0bbb14eb9b | ||
|
|
1ffa9ff51f | ||
|
|
435b53f1a0 | ||
|
|
406bdfad0d | ||
|
|
7961f8a664 | ||
|
|
4ca143e8af | ||
|
|
0707141998 | ||
|
|
cc861d6b70 | ||
|
|
de4e9c54f6 | ||
|
|
da671cd232 | ||
|
|
1d9696e614 | ||
|
|
afeef94900 | ||
|
|
860d9c4f29 | ||
|
|
4393191166 | ||
|
|
88daad524e | ||
|
|
66c58f8155 | ||
|
|
7bbb5be910 | ||
|
|
0dcb65bd56 | ||
|
|
2784b0f438 |
174
.github/workflows/generate-changelog.yml
vendored
Normal file
174
.github/workflows/generate-changelog.yml
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
name: Generate Changelog for Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "Release version (e.g., 0.0.97)"
|
||||
required: true
|
||||
type: string
|
||||
date:
|
||||
description: "Release date (YYYY-MM-DD format, defaults to today)"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
generate-changelog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Set release date
|
||||
id: set_date
|
||||
run: |
|
||||
if [ -z "${{ inputs.date }}" ]; then
|
||||
RELEASE_DATE=$(date +%Y-%m-%d)
|
||||
echo "Using today's date: $RELEASE_DATE"
|
||||
else
|
||||
RELEASE_DATE="${{ inputs.date }}"
|
||||
echo "Using provided date: $RELEASE_DATE"
|
||||
fi
|
||||
echo "release_date=$RELEASE_DATE" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Validate inputs
|
||||
run: |
|
||||
# Validate version format (basic check)
|
||||
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+.*$ ]]; then
|
||||
echo "Error: Version must be in format X.Y.Z (e.g., 0.0.97)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate date format if provided
|
||||
if [ -n "${{ inputs.date }}" ]; then
|
||||
if ! date -d "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
# Try macOS date format
|
||||
if ! date -j -f "%Y-%m-%d" "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
echo "Error: Date must be in YYYY-MM-DD format (e.g., 2025-12-04)"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Check for changelog fragments
|
||||
id: check_fragments
|
||||
run: |
|
||||
FRAGMENT_COUNT=$(find changelog -name "*.md" ! -name "_template.md.j2" | wc -l | tr -d ' ')
|
||||
echo "fragment_count=$FRAGMENT_COUNT" >> $GITHUB_OUTPUT
|
||||
|
||||
if [ "$FRAGMENT_COUNT" -eq "0" ]; then
|
||||
echo "❌ Error: No changelog fragments found in changelog/"
|
||||
echo ""
|
||||
echo "Cannot create a release without changelog entries."
|
||||
echo "Add changelog fragments to the changelog/ directory (e.g., 1234.added.md) and try again."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate fragment types
|
||||
VALID_TYPES="added changed deprecated removed fixed security"
|
||||
INVALID_FRAGMENTS=""
|
||||
|
||||
for file in changelog/*.md; do
|
||||
# Skip template
|
||||
if [[ "$file" == "changelog/_template.md.j2" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract type from filename (e.g., 1234.added.md -> added)
|
||||
filename=$(basename "$file")
|
||||
# Handle both 1234.added.md and 1234.added.2.md patterns
|
||||
type=$(echo "$filename" | sed -E 's/^[0-9]+\.([a-z]+)(\.[0-9]+)?\.md$/\1/')
|
||||
|
||||
# Check if type is valid
|
||||
if ! echo "$VALID_TYPES" | grep -wq "$type"; then
|
||||
INVALID_FRAGMENTS="$INVALID_FRAGMENTS\n - $filename (type: '$type')"
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -n "$INVALID_FRAGMENTS" ]; then
|
||||
echo "❌ Error: Invalid changelog fragment types found:"
|
||||
echo -e "$INVALID_FRAGMENTS"
|
||||
echo ""
|
||||
echo "Valid types are: $VALID_TYPES"
|
||||
echo "Example: 1234.added.md, 5678.fixed.md"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Found $FRAGMENT_COUNT changelog fragment(s)"
|
||||
echo "has_fragments=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Preview changelog
|
||||
run: |
|
||||
echo "## Preview of changelog for version ${{ inputs.version }}"
|
||||
echo ""
|
||||
uv run towncrier build --draft --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}"
|
||||
|
||||
- name: Build changelog
|
||||
run: |
|
||||
uv run towncrier build --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}" --yes
|
||||
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update changelog for version ${{ inputs.version }}"
|
||||
title: "Release ${{ inputs.version }} - Changelog Update"
|
||||
body: |
|
||||
## Changelog Update for Release ${{ inputs.version }}
|
||||
|
||||
This PR updates the CHANGELOG.md with all changes for version **${{ inputs.version }}**.
|
||||
|
||||
### Summary
|
||||
- **Version:** ${{ inputs.version }}
|
||||
- **Date:** ${{ steps.set_date.outputs.release_date }}
|
||||
- **Fragments processed:** ${{ steps.check_fragments.outputs.fragment_count }}
|
||||
|
||||
### What this PR does
|
||||
- ✅ Adds new release section to CHANGELOG.md
|
||||
- ✅ Removes processed changelog fragments
|
||||
- ✅ Ready to merge for release
|
||||
|
||||
### Next Steps
|
||||
1. Review the changelog entries below
|
||||
2. Make any necessary edits to CHANGELOG.md if needed
|
||||
3. Merge this PR
|
||||
4. Continue with your release process
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>📋 Preview of changes</summary>
|
||||
|
||||
The changelog has been updated with entries from the following fragments:
|
||||
|
||||
```bash
|
||||
${{ steps.check_fragments.outputs.fragment_count }} fragments processed
|
||||
```
|
||||
|
||||
</details>
|
||||
branch: changelog-${{ inputs.version }}
|
||||
delete-branch: true
|
||||
labels: |
|
||||
changelog
|
||||
release
|
||||
1
.github/workflows/python-compatibility.yaml
vendored
1
.github/workflows/python-compatibility.yaml
vendored
@@ -50,7 +50,6 @@ jobs:
|
||||
run: |
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra krisp \
|
||||
--no-extra ultravox \
|
||||
--no-extra local-smart-turn \
|
||||
--no-extra moondream \
|
||||
--no-extra mlx-whisper
|
||||
|
||||
@@ -11,7 +11,7 @@ build:
|
||||
jobs:
|
||||
post_install:
|
||||
- pip install uv
|
||||
- UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --group docs --all-extras --no-extra krisp --no-extra gstreamer --no-extra ultravox --no-extra local_smart_turn --no-extra moondream --no-extra riva --no-extra mlx-whisper
|
||||
- UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --group docs --all-extras --no-extra krisp --no-extra gstreamer --no-extra local_smart_turn --no-extra moondream --no-extra riva --no-extra mlx-whisper
|
||||
|
||||
sphinx:
|
||||
configuration: docs/api/conf.py
|
||||
|
||||
326
CHANGELOG.md
326
CHANGELOG.md
@@ -5,21 +5,270 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
<!-- towncrier release notes start -->
|
||||
|
||||
## [0.0.98] - 2025-12-17
|
||||
|
||||
### Added
|
||||
|
||||
- Added `wait_for_all` argument to the base `LLMService`. When enabled, this
|
||||
ensures all function calls complete before returning results to the LLM (i.e.,
|
||||
before running a new inference with those results).
|
||||
- Added `RimeNonJsonTTSService` which supports non-JSON streaming mode. This
|
||||
new class supports websocket streaming for the Arcana model.
|
||||
(PR [#3085](https://github.com/pipecat-ai/pipecat/pull/3085))
|
||||
|
||||
- Added additional functionality related to "thinking", for Google and
|
||||
Anthropic LLMs.
|
||||
|
||||
1. New typed parameters for Google and Anthropic LLMs that control the
|
||||
models' thinking behavior (like how much thinking to do, and whether to
|
||||
output thoughts or thought summaries):
|
||||
- `AnthropicLLMService.ThinkingConfig`
|
||||
- `GoogleLLMService.ThinkingConfig`
|
||||
2. New frames for representing thoughts output by LLMs:
|
||||
- `LLMThoughtStartFrame`
|
||||
- `LLMThoughtTextFrame`
|
||||
- `LLMThoughtEndFrame`
|
||||
3. A generic mechanism for recording LLM thoughts to context, used
|
||||
specifically to support Anthropic, whose thought signatures are expected
|
||||
to appear alongside the text of the thoughts within assistant context
|
||||
messages. See:
|
||||
- `LLMThoughtEndFrame.signature`
|
||||
- `LLMAssistantAggregator` handling of the above field
|
||||
- `AnthropicLLMAdapter` handling of `"thought"` context messages
|
||||
4. Google-specific logic for inserting thought signatures into the context,
|
||||
to help maintain thinking continuity in a chain of LLM calls. See:
|
||||
- `GoogleLLMService` sending `LLMMessagesAppendFrame`s to add
|
||||
LLM-specific
|
||||
`"thought_signature"` messages to context
|
||||
- `GeminiLLMAdapter` handling of `"thought_signature"` messages
|
||||
5. An expansion of `TranscriptProcessor` to process LLM thoughts in
|
||||
addition to user and assistant utterances. See:
|
||||
- `TranscriptProcessor(process_thoughts=True)` (defaults to `False`)
|
||||
- `ThoughtTranscriptionMessage`, which is now also emitted with the
|
||||
`"on_transcript_update"` event
|
||||
(PR [#3175](https://github.com/pipecat-ai/pipecat/pull/3175))
|
||||
|
||||
- Data and control frames can now be marked as non-interruptible by using the
|
||||
`UninterruptibleFrame` mixin. Frames marked as `UninterruptibleFrame` will
|
||||
not be interrupted during processing, and any queued frames of this type will
|
||||
be retained in the internal queues. This is useful when you need ordered
|
||||
frames (data or control) that should not be discarded or cancelled due to
|
||||
interruptions.
|
||||
(PR [#3189](https://github.com/pipecat-ai/pipecat/pull/3189))
|
||||
|
||||
- Added `on_conversation_detected` event to `VoicemailDetector`.
|
||||
(PR [#3207](https://github.com/pipecat-ai/pipecat/pull/3207))
|
||||
|
||||
- Added `x-goog-api-client` header with Pipecat's version to all Google
|
||||
services' requests.
|
||||
(PR [#3208](https://github.com/pipecat-ai/pipecat/pull/3208))
|
||||
|
||||
- Added support for the HeyGen LiveAvatar API (see https://www.liveavatar.com/).
|
||||
(PR [#3210](https://github.com/pipecat-ai/pipecat/pull/3210))
|
||||
|
||||
- Added to `AWSNovaSonicLLMService` functionality related to the new (and now
|
||||
default) Nova 2 Sonic model (`"amazon.nova-2-sonic-v1:0"`):
|
||||
|
||||
- Added the `endpointing_sensitivity` parameter to control how quickly the
|
||||
model decides the user has stopped speaking.
|
||||
- Made the assistant-response-trigger hack a no-op. It's only needed for
|
||||
the older Nova Sonic model.
|
||||
(PR [#3212](https://github.com/pipecat-ai/pipecat/pull/3212))
|
||||
|
||||
- [Ultravox Realtime](https://docs.ultravox.ai) is now a supported
|
||||
speech-to-speech service.
|
||||
|
||||
- Added `UltravoxRealtimeLLMService` for the integration.
|
||||
- Added `49-ultravox-realtime.py` example (with tool calling).
|
||||
(PR [#3227](https://github.com/pipecat-ai/pipecat/pull/3227))
|
||||
|
||||
- Added Daily PSTN dial-in support to the development runner with `--dialin`
|
||||
flag. This includes:
|
||||
|
||||
- `/daily-dialin-webhook` endpoint that handles incoming Daily PSTN webhooks
|
||||
- Automatic Daily room creation with SIP configuration
|
||||
- `DialinSettings` and `DailyDialinRequest` types in `pipecat.runner.types`
|
||||
for type-safe dial-in data
|
||||
- The runner now mimics Pipecat Cloud's dial-in webhook handling for local
|
||||
development
|
||||
(PR [#3235](https://github.com/pipecat-ai/pipecat/pull/3235))
|
||||
|
||||
- Added Gladia session ID to logs for `GladiaSTTService`.
|
||||
(PR [#3236](https://github.com/pipecat-ai/pipecat/pull/3236))
|
||||
|
||||
- Added `InworldHttpTTSService` which uses Inworld's HTTP based TTS service in
|
||||
either streaming or non-streaming mode. Note: This class was previously named
|
||||
`InworldTTSService`.
|
||||
(PR [#3239](https://github.com/pipecat-ai/pipecat/pull/3239))
|
||||
|
||||
- Added `language_hints_strict` parameter to `SonioxSTTService` to strictly
|
||||
enforce language hints. This ensures that transcription occurs in the
|
||||
specified language.
|
||||
(PR [#3245](https://github.com/pipecat-ai/pipecat/pull/3245))
|
||||
|
||||
- Added Pipecat library version info to the `about` field in the `bot-ready`
|
||||
RTVI message.
|
||||
(PR [#3248](https://github.com/pipecat-ai/pipecat/pull/3248))
|
||||
|
||||
- Added `VisionFullResponseStartFrame`, `VisionFullResponseEndFrame` and
|
||||
`VisionTextFrame`. These are used by vision services similar to LLM
|
||||
services.
|
||||
(PR [#3252](https://github.com/pipecat-ai/pipecat/pull/3252))
|
||||
|
||||
### Changed
|
||||
|
||||
- Improved interruption handling to prevent bots from repeating themselves.
|
||||
LLM services that return multiple sentences in a single response (e.g.,
|
||||
`GoogleLLMService`) are now split into individual sentences before being sent
|
||||
to TTS. This ensures interruptions occur at sentence boundaries, preventing
|
||||
the bot from repeating content after being interrupted during long responses.
|
||||
- `FunctionCallInProgressFrame` and `FunctionCallResultFrame` have changed from
|
||||
system frames to a control frame and a data frame, respectively, and are
|
||||
now both marked as `UninterruptibleFrame`.
|
||||
(PR [#3189](https://github.com/pipecat-ai/pipecat/pull/3189))
|
||||
|
||||
- `UserBotLatencyLogObserver` now uses `VADUserStartedSpeakingFrame` and
|
||||
`VADUserStoppedSpeakingFrame` to determine latency from user stopped speaking
|
||||
to bot started speaking.
|
||||
(PR [#3206](https://github.com/pipecat-ai/pipecat/pull/3206))
|
||||
|
||||
- Updated `HeyGenVideoService` and `HeyGenTransport` to support both HeyGen
|
||||
APIs (Interactive Avatar and Live Avatar).
|
||||
Using them is as simple as specifying the `service_type` when creating the
|
||||
`HeyGenVideoService` and the `HeyGenTransport`:
|
||||
|
||||
```python
|
||||
heyGen = HeyGenVideoService(
|
||||
api_key=os.getenv("HEYGEN_LIVE_AVATAR_API_KEY"),
|
||||
service_type=ServiceType.LIVE_AVATAR,
|
||||
session=session,
|
||||
)
|
||||
```
|
||||
|
||||
(PR [#3210](https://github.com/pipecat-ai/pipecat/pull/3210))
|
||||
|
||||
- Made `"amazon.nova-2-sonic-v1:0"` the new default model for
|
||||
`AWSNovaSonicLLMService`.
|
||||
(PR [#3212](https://github.com/pipecat-ai/pipecat/pull/3212))
|
||||
|
||||
- Updated the `run_inference` methods in the LLM service classes
|
||||
(`AnthropicLLMService`, `AWSBedrockLLMService`, `GoogleLLMService`, and
|
||||
`OpenAILLMService` and its base classes) to use the provided LLM
|
||||
configuration parameters.
|
||||
(PR [#3214](https://github.com/pipecat-ai/pipecat/pull/3214))
|
||||
|
||||
- Updated default models for:
|
||||
|
||||
- `GeminiLiveLLMService` to `gemini-2.5-flash-native-audio-preview-12-2025`.
|
||||
- `GeminiLiveVertexLLMService` to `gemini-live-2.5-flash-native-audio`.
|
||||
(PR [#3228](https://github.com/pipecat-ai/pipecat/pull/3228))
|
||||
|
||||
- Changed the `reason` field in `EndFrame`, `CancelFrame`, `EndTaskFrame`, and
|
||||
`CancelTaskFrame` from `str` to `Any` to indicate that it can hold values
|
||||
other than strings.
|
||||
(PR [#3231](https://github.com/pipecat-ai/pipecat/pull/3231))
|
||||
|
||||
- Updated websocket STT services to use the `WebsocketSTTService` base class.
|
||||
This base class manages the websocket connection and handles reconnects.
|
||||
Updated services:
|
||||
|
||||
- `AssemblyAISTTService`
|
||||
- `AWSTranscribeSTTService`
|
||||
- `GladiaSTTService`
|
||||
- `SonioxSTTService`
|
||||
(PR [#3236](https://github.com/pipecat-ai/pipecat/pull/3236))
|
||||
|
||||
- Changed Inworld's TTS service implementations:
|
||||
|
||||
- Previously, the HTTP implementation was named `InworldTTSService`. That
|
||||
has been moved to `InworldHttpTTSService`. This service now supports
|
||||
word-timestamp alignment data in both streaming and non-streaming modes.
|
||||
- Updated the `InworldTTSService` class to use Inworld's Websocket API.
|
||||
This class now has support for word-timestamp alignment data and tracks
|
||||
contexts for each user turn.
|
||||
(PR [#3239](https://github.com/pipecat-ai/pipecat/pull/3239))
|
||||
|
||||
- ⚠️ Breaking change: `WordTTSService.start_word_timestamps()` and
|
||||
`WordTTSService.reset_word_timestamps()` are now async.
|
||||
(PR [#3240](https://github.com/pipecat-ai/pipecat/pull/3240))
|
||||
|
||||
- Updated the current RTVI version to 1.1.0 to reflect recent additions and
|
||||
deprecations.
|
||||
|
||||
- New RTVI Messages: `send-text` and `bot-output`
|
||||
- Deprecated Messages: `append-to-context` and `bot-transcription`
|
||||
(PR [#3248](https://github.com/pipecat-ai/pipecat/pull/3248))
|
||||
|
||||
- `MoondreamService` now pushes `VisionFullResponseStartFrame`,
|
||||
`VisionFullResponseEndFrame` and `VisionTextFrame`.
|
||||
(PR [#3252](https://github.com/pipecat-ai/pipecat/pull/3252))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `FalSmartTurnAnalyzer` and `LocalSmartTurnAnalyzer` are deprecated and will
|
||||
be removed in a future version. Use `LocalSmartTurnAnalyzerV3` instead.
|
||||
(PR [#3219](https://github.com/pipecat-ai/pipecat/pull/3219))
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed the deprecated VLLM-based open source Ultravox STT service.
|
||||
(PR [#3227](https://github.com/pipecat-ai/pipecat/pull/3227))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug in `AWSNovaSonicLLMService` where we would mishandle cancelled
|
||||
tool calls in the context, resulting in errors.
|
||||
(PR [#3212](https://github.com/pipecat-ai/pipecat/pull/3212))
|
||||
|
||||
- Better support conversation history with Gemini 2.5 Flash Image (model
|
||||
"gemini-2.5-flash-image"). Prior to this fix, the model had no memory of
|
||||
previous images it had generated, so it wouldn't be able to iterate on
|
||||
them.
|
||||
(PR [#3224](https://github.com/pipecat-ai/pipecat/pull/3224))
|
||||
|
||||
- Support conversations with Gemini 3 Pro Image (model
|
||||
"gemini-3-pro-image-preview"). Prior to this fix, after the model generated
|
||||
an image the conversation would not be able to progress.
|
||||
(PR [#3224](https://github.com/pipecat-ai/pipecat/pull/3224))
|
||||
|
||||
- Fixed an issue where `ElevenLabsHttpTTSService` was not updating
|
||||
voice settings when receiving a `TTSUpdateSettingsFrame`.
|
||||
(PR [#3226](https://github.com/pipecat-ai/pipecat/pull/3226))
|
||||
|
||||
- Fixed the return type for `SmallWebRTCRequestHandler.handle_web_request()`
|
||||
function.
|
||||
(PR [#3230](https://github.com/pipecat-ai/pipecat/pull/3230))
|
||||
|
||||
- Fixed a bug in LLM context audio content handling.
|
||||
(PR [#3234](https://github.com/pipecat-ai/pipecat/pull/3234))
|
||||
|
||||
- In `GladiaSTTService`, reset the `_bytes_sent` counter on connecting the
|
||||
websocket. This avoids unnecessary audio buffer trimming.
|
||||
(PR [#3236](https://github.com/pipecat-ai/pipecat/pull/3236))
|
||||
|
||||
- Fixed a TTS service word-timestamp issue that could cause generated
|
||||
`TTSTextFrame` instances to have an incorrect pts (`pts = -1`).
|
||||
(PR [#3240](https://github.com/pipecat-ai/pipecat/pull/3240))
|
||||
|
||||
- Fixed an issue in `SimpleTextAggregator` where spaces were not being stripped
|
||||
before returning the aggregation. This resulted in an extra space for TTS
|
||||
services that don't support word-timestamp alignment data.
|
||||
(PR [#3247](https://github.com/pipecat-ai/pipecat/pull/3247))
|
||||
|
||||
## [0.0.97] - 2025-12-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
|
||||
speech-to-text and text-to-speech functionality using Gradium's API.
|
||||
|
||||
- Additions for `AsyncAITTSService` and `AsyncAIHttpTTSService`:
|
||||
|
||||
- Added new `languages`: `pt`, `nl`, `ar`, `ru`, `ro`, `ja`, `he`, `hy`,
|
||||
`tr`, `hi`, `zh`.
|
||||
- Updated the default model to `asyncflow_multilingual_v1.0` for improved
|
||||
accuracy and broader language coverage.
|
||||
|
||||
- Added optional tool and tool output filters for MCP services.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Deepgram logging to include Deepgram request IDs for improved
|
||||
debugging.
|
||||
|
||||
- Text Aggregation Improvements:
|
||||
|
||||
@@ -31,21 +280,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
|
||||
the base class's sentence detection logic.
|
||||
|
||||
- Improved interruption handling to prevent bots from repeating themselves. LLM
|
||||
services that return multiple sentences in a single response (e.g.,
|
||||
`GoogleLLMService`) are now split into individual sentences before being sent
|
||||
to TTS. This ensures interruptions occur at sentence boundaries, preventing
|
||||
the bot from repeating content after being interrupted during long responses.
|
||||
|
||||
- Updated `AICFilter` to use Quail STT as the default model
|
||||
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
|
||||
interaction (e.g., voice agents, speech-to-text) and operates at a native
|
||||
sample rate of 16 kHz with fixed enhancement parameters.
|
||||
|
||||
- Updated Deepgram logging to include Deepgram request IDs for improved debugging.
|
||||
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is
|
||||
called with an exception, the file name and line number where the exception
|
||||
occurred are now logged.
|
||||
|
||||
- Updated Smart Turn model weights to v3.1.
|
||||
|
||||
- Smart Turn analyzer now uses the full context of the turn rather than just
|
||||
the audio since VAD last triggered.
|
||||
|
||||
- Updated `CartesiaSTTService` to return the full transcription `result` in the
|
||||
`TranscriptionFrame` and `InterimTranscriptionFrame`. This provides access to
|
||||
word timestamp data.
|
||||
|
||||
- `HumeTTSService` changes:
|
||||
|
||||
- Added tracking headers (`X-Hume-Client-Name` and `X-Hume-Client-Version`)
|
||||
to all requests made by `HumeTTSService` to the Hume API for better usage
|
||||
tracking and analytics.
|
||||
- Added `stop()` and `cancel()` cleanup methods to `HumeTTSService` to
|
||||
properly close the HTTP client and prevent resource leaks.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
|
||||
|
||||
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
|
||||
has any effect. Noise gating is now handled automatically by the AIC VAD
|
||||
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
|
||||
|
||||
- NVIDIA Services name changes (all functionality is unchanged):
|
||||
|
||||
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
|
||||
@@ -54,29 +322,41 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Use `uv pip install pipecat-ai[nvidia]` instead of
|
||||
`uv pip install pipecat-ai[riva]`
|
||||
|
||||
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
|
||||
has any effect. Noise gating is now handled automatically by the AIC VAD
|
||||
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
|
||||
|
||||
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
|
||||
services.
|
||||
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
|
||||
multiple times for `KEEP` or `AGGREGATE` patterns.
|
||||
|
||||
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
|
||||
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
|
||||
|
||||
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
|
||||
multiple times for `KEEP` or `AGGREGATE` patterns.
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was always
|
||||
set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
|
||||
spoken. Now, audio is flushed when the TTS service receives the
|
||||
`LLMFullResponseEndFrame` or `EndFrame`.
|
||||
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
|
||||
always set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
|
||||
incorrectly pushed after a function call. This caused an issue with the
|
||||
voice-ui-kit's conversational panel rendering of the LLM output after a
|
||||
function call.
|
||||
|
||||
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
|
||||
services.
|
||||
|
||||
- Fixed an issue that caused `WebsocketService` instances to attempt
|
||||
reconnection during shutdown.
|
||||
|
||||
- Fixed an issue in `ElevenLabsTTSService` where character usage metrics were
|
||||
only reported on the first TTS generation per turn.
|
||||
|
||||
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃
|
||||
|
||||
### Added
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -17,24 +17,121 @@ We welcome contributions of all kinds! Your help is appreciated. Follow these st
|
||||
git checkout -b your-branch-name
|
||||
```
|
||||
4. **Make your changes**: Edit or add files as necessary.
|
||||
5. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
6. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
5. **Add a changelog entry**: Create a changelog fragment file (see [Changelog Entries](#changelog-entries) below).
|
||||
6. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
7. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
|
||||
```bash
|
||||
git commit -m "Description of your changes"
|
||||
```
|
||||
|
||||
7. **Push your changes**: Push your branch to your forked repository.
|
||||
8. **Push your changes**: Push your branch to your forked repository.
|
||||
|
||||
```bash
|
||||
git push origin your-branch-name
|
||||
```
|
||||
|
||||
8. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
9. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
> Important: Describe the changes you've made clearly!
|
||||
|
||||
Our maintainers will review your PR, and once everything is good, your contributions will be merged!
|
||||
|
||||
## Changelog Entries
|
||||
|
||||
Every pull request that makes a user-facing change should include a changelog entry. We use a changelog fragment system to avoid merge conflicts.
|
||||
|
||||
### Creating a Changelog Fragment
|
||||
|
||||
1. Create a new file in the `changelog/` directory with this naming pattern:
|
||||
|
||||
```
|
||||
<PR_number>.<type>.md
|
||||
```
|
||||
|
||||
2. Choose the appropriate type:
|
||||
|
||||
- `added.md` - New features
|
||||
- `changed.md` - Changes in existing functionality
|
||||
- `deprecated.md` - Soon-to-be removed features
|
||||
- `removed.md` - Removed features
|
||||
- `fixed.md` - Bug fixes
|
||||
- `security.md` - Security fixes
|
||||
|
||||
3. Write your changelog entry as a Markdown bullet point. Include the `-` at the start:
|
||||
|
||||
**Example files:**
|
||||
|
||||
`changelog/1234.added.md`:
|
||||
|
||||
```markdown
|
||||
- Added support for Anthropic Claude 3.5 Sonnet with improved streaming performance.
|
||||
```
|
||||
|
||||
`changelog/5678.fixed.md`:
|
||||
|
||||
```markdown
|
||||
- Fixed an issue where audio frames were dropped during high-load scenarios.
|
||||
```
|
||||
|
||||
**For entries with nested bullets:**
|
||||
|
||||
`changelog/1234.changed.md`:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
### Multiple Changes in One PR
|
||||
|
||||
**Different types of changes:** Create separate fragment files for each type:
|
||||
|
||||
```
|
||||
changelog/1234.added.md
|
||||
changelog/1234.fixed.md
|
||||
```
|
||||
|
||||
**Multiple changes of the same type:** Create numbered fragment files:
|
||||
|
||||
```
|
||||
changelog/1234.changed.md
|
||||
changelog/1234.changed.2.md
|
||||
```
|
||||
|
||||
**Related changes:** Use nested bullets in a single fragment:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
**Rule of thumb:** One logical change per fragment file. If changes are unrelated, use separate files.
|
||||
|
||||
### Preview Your Changes
|
||||
|
||||
To see what your changelog entry will look like:
|
||||
|
||||
```bash
|
||||
towncrier build --draft --version Unreleased
|
||||
```
|
||||
|
||||
This won't modify any files, just show you a preview.
|
||||
|
||||
### When to Skip Changelog Entries
|
||||
|
||||
You can skip adding a changelog entry for:
|
||||
|
||||
- Documentation-only changes
|
||||
- Internal refactoring with no user-facing impact
|
||||
- Test-only changes
|
||||
- CI/build configuration changes
|
||||
|
||||
If you're unsure whether your change needs a changelog entry, ask in your PR!
|
||||
|
||||
## Dependency Management
|
||||
|
||||
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. The `uv.lock` file is committed to ensure reproducible builds.
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
</div></h1>
|
||||
|
||||
[](https://pypi.org/project/pipecat-ai)  [](https://codecov.io/gh/pipecat-ai/pipecat) [](https://docs.pipecat.ai) [](https://discord.gg/pipecat) [](https://deepwiki.com/pipecat-ai/pipecat)
|
||||
[](https://getmanta.ai/pipecat)
|
||||
|
||||
# 🎙️ Pipecat: Real-Time Voice & Multimodal AI Agents
|
||||
|
||||
@@ -74,10 +73,10 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai), Ultravox |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
@@ -154,7 +153,6 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra ultravox # (ultravox not fully supported on macOS)
|
||||
```
|
||||
|
||||
3. Install the git pre-commit hooks:
|
||||
|
||||
16
changelog/_template.md.j2
Normal file
16
changelog/_template.md.j2
Normal file
@@ -0,0 +1,16 @@
|
||||
{% for section, _ in sections.items() %}
|
||||
{% if sections[section] %}
|
||||
{% for category, val in definitions.items() if category in sections[section]%}
|
||||
### {{ definitions[category]['name'] }}
|
||||
|
||||
{% for text, values in sections[section][category].items() %}
|
||||
{{ text }}
|
||||
(PR {{ values|join(', ') }})
|
||||
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
No significant changes.
|
||||
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# Build docs using uv
|
||||
echo "Installing dependencies with uv..."
|
||||
uv sync --group docs --all-extras --no-extra krisp --no-extra gstreamer --no-extra ultravox --no-extra local_smart_turn --no-extra moondream --no-extra riva --no-extra mlx-whisper
|
||||
uv sync --group docs --all-extras --no-extra krisp --no-extra gstreamer --no-extra local_smart_turn --no-extra moondream --no-extra riva --no-extra mlx-whisper
|
||||
|
||||
# Check if sphinx-build is available
|
||||
if ! uv run sphinx-build --version &> /dev/null; then
|
||||
@@ -24,4 +24,4 @@ if [ $? -eq 0 ]; then
|
||||
else
|
||||
echo "Documentation build failed!" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -61,9 +61,6 @@ autodoc_mock_imports = [
|
||||
# OpenCV - sometimes has import issues during docs build
|
||||
"cv2",
|
||||
# Heavy ML packages excluded from ReadTheDocs
|
||||
# ultravox dependencies
|
||||
"vllm",
|
||||
"vllm.engine.arg_utils",
|
||||
# local-smart-turn dependencies
|
||||
"coremltools",
|
||||
"coremltools.models",
|
||||
|
||||
@@ -73,6 +73,9 @@ GOOGLE_CLOUD_PROJECT_ID=...
|
||||
GOOGLE_CLOUD_LOCATION=...
|
||||
GOOGLE_TEST_CREDENTIALS=...
|
||||
|
||||
# Gradium
|
||||
GRADIUM_API_KEY=...
|
||||
|
||||
# Grok
|
||||
GROK_API_KEY=...
|
||||
|
||||
@@ -81,6 +84,7 @@ GROQ_API_KEY=...
|
||||
|
||||
# Heygen
|
||||
HEYGEN_API_KEY=...
|
||||
HEYGEN_LIVE_AVATAR_API_KEY=...
|
||||
|
||||
# Hume
|
||||
HUME_API_KEY=...
|
||||
@@ -187,8 +191,11 @@ TOGETHER_API_KEY=...
|
||||
TWILIO_ACCOUNT_SID=...
|
||||
TWILIO_AUTH_TOKEN=...
|
||||
|
||||
# Ultravox Realtime
|
||||
ULTRAVOX_API_KEY=...
|
||||
|
||||
# WhatsApp
|
||||
WHATSAPP_TOKEN=...
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
|
||||
WHATSAPP_PHONE_NUMBER_ID=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
@@ -15,26 +14,26 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.inworld.tts import InworldTTSService
|
||||
from pipecat.services.inworld.tts import InworldHttpTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
@@ -58,22 +57,18 @@ transport_params = {
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
logger.info("Starting bot")
|
||||
|
||||
# Create an HTTP session
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
# Inworld TTS Service - Unified streaming and non-streaming
|
||||
# Set streaming=True for real-time audio, streaming=False for complete audio generation
|
||||
streaming = True # Toggle this to switch between modes
|
||||
|
||||
tts = InworldTTSService(
|
||||
tts = InworldHttpTTSService(
|
||||
api_key=os.getenv("INWORLD_API_KEY", ""),
|
||||
aiohttp_session=session,
|
||||
voice_id="Ashley",
|
||||
model="inworld-tts-1",
|
||||
streaming=streaming, # True: real-time chunks, False: complete audio then playback
|
||||
# Set to False for non-streaming mode or True for streaming mode.
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
@@ -81,22 +76,25 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are very knowledgable about dogs. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
"content": "You are a helpful AI demonstrating Inworld AI's TTS. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a friendly and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
rtvi = RTVIProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -106,19 +104,27 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
}
|
||||
),
|
||||
],
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
141
examples/foundational/07ab-interruptible-inworld.py
Normal file
141
examples/foundational/07ab-interruptible-inworld.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.inworld.tts import InworldTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = InworldTTSService(
|
||||
api_key=os.getenv("INWORLD_API_KEY", ""),
|
||||
voice_id="Ashley",
|
||||
model="inworld-tts-1",
|
||||
temperature=1.1,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI demonstrating Inworld AI's TTS. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a friendly and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
}
|
||||
),
|
||||
],
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -14,32 +13,23 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.ultravox.stt import UltravoxSTTService
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# NOTE: This example requires GPU resources to run efficiently.
|
||||
# The Ultravox model is compute-intensive and performs best with GPU acceleration.
|
||||
# This can be deployed on cloud GPU providers like Cerebrium.ai for optimal performance.
|
||||
|
||||
|
||||
# Want to initialize the ultravox processor since it takes time to load the model and dont
|
||||
# want to load it every time the pipeline is run
|
||||
ultravox_processor = UltravoxSTTService(
|
||||
model_name="fixie-ai/ultravox-v0_5-llama-3_1-8b",
|
||||
hf_token=os.getenv("HF_TOKEN"),
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -68,17 +58,34 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.environ.get("CARTESIA_API_KEY"),
|
||||
voice_id="97f4b8fb-f2fe-444b-bb9a-c109783a857a",
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
ultravox_processor,
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -94,6 +101,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
@@ -71,9 +71,9 @@ def build_agent(model_id: str, max_tokens: int):
|
||||
@tool
|
||||
def check_weather(location: str) -> str:
|
||||
if location.lower() == "san francisco":
|
||||
return "The weather in San Francisco is sunny and 30 degrees."
|
||||
return "The weather in San Francisco is sunny and 75 degrees."
|
||||
elif location.lower() == "sydney":
|
||||
return "The weather in Sydney is cloudy and 20 degrees."
|
||||
return "The weather in Sydney is cloudy and 60 degrees."
|
||||
else:
|
||||
return "I'm not sure about the weather in that location."
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash-image",
|
||||
# model="gemini-3-pro-image-preview", # A more powerful model, but slower
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -136,7 +136,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Hello! I'm your AI assistant. I can help you with a variety of tasks. What would you like to know?",
|
||||
"content": "You are an AI assistant. You can help with a variety of tasks. Introduce yourself and ask the user what they would like to know.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@@ -75,8 +75,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
# turn on thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}),)
|
||||
# force a certain amount of thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(
|
||||
# thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096)
|
||||
# ),
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -75,8 +75,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
# turn on thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}),)
|
||||
# force a certain amount of thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(
|
||||
# thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096)
|
||||
# ),
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -224,8 +224,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
# turn on thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}),
|
||||
# force a certain amount of thinking if you want it
|
||||
# params=GoogleLLMService.InputParams(
|
||||
# thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096)
|
||||
# ),
|
||||
)
|
||||
|
||||
tts = GoogleTTSService(
|
||||
|
||||
@@ -76,7 +76,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = FireworksLLMService(
|
||||
api_key=os.getenv("FIREWORKS_API_KEY"),
|
||||
model="accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
model="accounts/fireworks/models/gpt-oss-20b",
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
|
||||
@@ -17,7 +17,6 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
|
||||
@@ -20,7 +20,6 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
|
||||
@@ -18,7 +18,6 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
|
||||
@@ -64,11 +64,14 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@@ -88,6 +91,23 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# full list of tools available from rijksmuseum MCP:
|
||||
# - get_artwork_details
|
||||
# - get_artwork_image
|
||||
# - get_user_sets
|
||||
# - get_user_set_details
|
||||
# - open_image_in_browser
|
||||
# - get_artist_timeline
|
||||
|
||||
mcp_tools_filter = ["get_artwork_details", "get_artwork_image", "open_image_in_browser"]
|
||||
|
||||
|
||||
def open_image_output_filter(output: str):
|
||||
pattern = r"Successfully opened image in browser: "
|
||||
text_to_print = re.sub(pattern, "", output)
|
||||
print(f"🖼️ link to high resolution artwork: {text_to_print}")
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -136,7 +156,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
),
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
|
||||
@@ -67,13 +67,14 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
try:
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -33,11 +35,21 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
temperature = (
|
||||
random.randint(60, 85)
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
# Simulate a long network delay.
|
||||
# You can continue chatting while waiting for this to complete.
|
||||
# With Nova 2 Sonic (the default model), the assistant will respond
|
||||
# appropriately once the function call is complete.
|
||||
await asyncio.sleep(5)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"location": params.arguments["location"],
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
@@ -91,23 +103,31 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Specify initial system instruction.
|
||||
# HACK: note that, for now, we need to inject a special bit of text into this instruction to
|
||||
# allow the first assistant response to be programmatically triggered (which happens in the
|
||||
# on_client_connected handler, below)
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken dialog exchanging "
|
||||
"the transcripts of a natural real-time conversation. Keep your responses short, generally "
|
||||
"two or three sentences for chatty scenarios. "
|
||||
f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
|
||||
"two or three sentences for chatty scenarios."
|
||||
# HACK: if using the older Nova Sonic (pre-2) model, note that you need to inject a special
|
||||
# bit of text into this instruction to allow the first assistant response to be
|
||||
# programmatically triggered (which happens in the on_client_connected handler)
|
||||
# f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}"
|
||||
)
|
||||
|
||||
# Create the AWS Nova Sonic LLM service
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"), # as of 2025-05-06, us-east-1 is the only supported region
|
||||
# as of 2025-12-09, these are the supported regions:
|
||||
# - Nova 2 Sonic (the default model):
|
||||
# - us-east-1
|
||||
# - us-west-2
|
||||
# - ap-northeast-1
|
||||
# - Nova Sonic (the older model):
|
||||
# - us-east-1
|
||||
# - ap-northeast-1
|
||||
region=os.getenv("AWS_REGION"),
|
||||
session_token=os.getenv("AWS_SESSION_TOKEN"),
|
||||
voice_id="tiffany", # matthew, tiffany, amy
|
||||
voice_id="tiffany",
|
||||
# you could choose to pass instruction here rather than via context
|
||||
# system_instruction=system_instruction
|
||||
# you could choose to pass tools here rather than via context
|
||||
@@ -117,7 +137,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Register function for function calls
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function(
|
||||
"get_current_weather", fetch_weather_from_api, cancel_on_interruption=False
|
||||
)
|
||||
|
||||
# Set up context and context management.
|
||||
context = LLMContext(
|
||||
@@ -159,10 +181,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
# HACK: for now, we need this special way of triggering the first assistant response in AWS
|
||||
# Nova Sonic. Note that this trigger requires a special corresponding bit of text in the
|
||||
# system instruction. In the future, simply queueing the context frame should be sufficient.
|
||||
await llm.trigger_assistant_response()
|
||||
# HACK: if using the older Nova Sonic (pre-2) model, you need this special way of
|
||||
# triggering the first assistant response. Note that this trigger requires a special
|
||||
# corresponding bit of text in the system instruction.
|
||||
# await llm.trigger_assistant_response()
|
||||
|
||||
# Handle client disconnection events
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
|
||||
@@ -25,7 +25,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.heygen.api import AvatarQuality, NewSessionRequest
|
||||
from pipecat.services.heygen.client import ServiceType
|
||||
from pipecat.services.heygen.video import HeyGenVideoService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams, DailyTransport
|
||||
@@ -73,11 +73,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
|
||||
|
||||
heyGen = HeyGenVideoService(
|
||||
api_key=os.getenv("HEYGEN_API_KEY"),
|
||||
api_key=os.getenv("HEYGEN_LIVE_AVATAR_API_KEY"),
|
||||
service_type=ServiceType.LIVE_AVATAR,
|
||||
session=session,
|
||||
session_request=NewSessionRequest(
|
||||
avatar_id="Shawn_Therapist_public", version="v2", quality=AvatarQuality.high
|
||||
),
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -113,8 +113,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
@voicemail.event_handler("on_conversation_detected")
|
||||
async def on_conversation_detected(processor):
|
||||
logger.info("Conversation detected!")
|
||||
|
||||
@voicemail.event_handler("on_voicemail_detected")
|
||||
async def handle_voicemail(processor):
|
||||
async def on_voicemail_detected(processor):
|
||||
logger.info("Voicemail detected! Leaving a message...")
|
||||
|
||||
# Push frames using standard Pipecat pattern
|
||||
|
||||
221
examples/foundational/49-ultravox-realtime.py
Normal file
221
examples/foundational/49-ultravox-realtime.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.ultravox.llm import OneShotInputParams, UltravoxRealtimeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def get_secret_menu(params: FunctionCallParams):
|
||||
category = params.arguments.get("category", "both")
|
||||
logger.debug(f"Fetching secret menu with category: {category}")
|
||||
items = []
|
||||
if category in {"donuts", "both"}:
|
||||
items.append(
|
||||
{
|
||||
"name": "Butter Pecan Ice Cream (one scoop)",
|
||||
"price": "$2.99",
|
||||
}
|
||||
)
|
||||
if category in {"drinks", "both"}:
|
||||
items.append(
|
||||
{
|
||||
"name": "Banana Smoothie",
|
||||
"price": "$4.99",
|
||||
}
|
||||
)
|
||||
await params.result_callback(
|
||||
{
|
||||
"date": datetime.date.today().isoformat(),
|
||||
"items": items,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
system_prompt = f"""
|
||||
You are a drive-thru order taker for a donut shop called "Dr. Donut". Local time is currently: {datetime.datetime.now().isoformat()}
|
||||
The user is talking to you over voice on their phone, and your response will be read out loud with realistic text-to-speech (TTS) technology.
|
||||
|
||||
Follow every direction here when crafting your response:
|
||||
|
||||
1. Use natural, conversational language that is clear and easy to follow (short sentences, simple words).
|
||||
1a. Be concise and relevant: Most of your responses should be a sentence or two, unless you're asked to go deeper. Don't monopolize the conversation.
|
||||
1b. Use discourse markers to ease comprehension. Never use the list format.
|
||||
|
||||
2. Keep the conversation flowing.
|
||||
2a. Clarify: when there is ambiguity, ask clarifying questions, rather than make assumptions.
|
||||
2b. Don't implicitly or explicitly try to end the chat (i.e. do not end a response with "Talk soon!", or "Enjoy!").
|
||||
2c. Sometimes the user might just want to chat. Ask them relevant follow-up questions.
|
||||
2d. Don't ask them if there's anything else they need help with (e.g. don't say things like "How can I assist you further?").
|
||||
|
||||
3. Remember that this is a voice conversation:
|
||||
3a. Don't use lists, markdown, bullet points, or other formatting that's not typically spoken.
|
||||
3b. Type out numbers in words (e.g. 'twenty twelve' instead of the year 2012)
|
||||
3c. If something doesn't make sense, it's likely because you misheard them. There wasn't a typo, and the user didn't mispronounce anything.
|
||||
|
||||
Remember to follow these rules absolutely, and do not refer to these rules, even if you're asked about them.
|
||||
|
||||
When talking with the user, use the following script:
|
||||
1. Take their order, acknowledging each item as it is ordered. If it's not clear which menu item the user is ordering, ask them to clarify.
|
||||
DO NOT add an item to the order unless it's one of the items on the menu below.
|
||||
2. Once the order is complete, repeat back the order.
|
||||
2a. If the user only ordered a drink, ask them if they would like to add a donut to their order.
|
||||
2b. If the user only ordered donuts, ask them if they would like to add a drink to their order.
|
||||
2c. If the user ordered both drinks and donuts, don't suggest anything.
|
||||
3. Total up the price of all ordered items and inform the user.
|
||||
4. Ask the user to pull up to the drive thru window.
|
||||
If the user asks for something that's not on the menu, inform them of that fact, and suggest the most similar item on the menu.
|
||||
If the user says something unrelated to your role, responed with "Um... this is a Dr. Donut."
|
||||
If the user says "thank you", respond with "My pleasure."
|
||||
If the user asks about what's on the menu, DO NOT read the entire menu to them. Instead, give a couple suggestions.
|
||||
|
||||
The menu of available items is as follows:
|
||||
|
||||
# DONUTS
|
||||
|
||||
PUMPKIN SPICE ICED DOUGHNUT $1.29
|
||||
PUMPKIN SPICE CAKE DOUGHNUT $1.29
|
||||
OLD FASHIONED DOUGHNUT $1.29
|
||||
CHOCOLATE ICED DOUGHNUT $1.09
|
||||
CHOCOLATE ICED DOUGHNUT WITH SPRINKLES $1.09
|
||||
RASPBERRY FILLED DOUGHNUT $1.09
|
||||
BLUEBERRY CAKE DOUGHNUT $1.09
|
||||
STRAWBERRY ICED DOUGHNUT WITH SPRINKLES $1.09
|
||||
LEMON FILLED DOUGHNUT $1.09
|
||||
DOUGHNUT HOLES $3.99
|
||||
|
||||
# COFFEE & DRINKS
|
||||
|
||||
PUMPKIN SPICE COFFEE $2.59
|
||||
PUMPKIN SPICE LATTE $4.59
|
||||
REGULAR BREWED COFFEE $1.79
|
||||
DECAF BREWED COFFEE $1.79
|
||||
LATTE $3.49
|
||||
CAPPUCINO $3.49
|
||||
CARAMEL MACCHIATO $3.49
|
||||
MOCHA LATTE $3.49
|
||||
CARAMEL MOCHA LATTE $3.49
|
||||
|
||||
There is also a secret menu that changes daily. If the user asks about it, use the get_secret_menu tool to look up today's secret menu items.
|
||||
"""
|
||||
|
||||
secret_menu_function = FunctionSchema(
|
||||
name="get_secret_menu",
|
||||
description="Get today's secret menu items",
|
||||
properties={
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["donuts", "drinks", "both"],
|
||||
"description": "The category of secret menu items to retrieve. Defaults to both.",
|
||||
},
|
||||
},
|
||||
required=[],
|
||||
)
|
||||
|
||||
llm = UltravoxRealtimeLLMService(
|
||||
params=OneShotInputParams(
|
||||
api_key=os.getenv("ULTRAVOX_API_KEY"),
|
||||
system_prompt=system_prompt,
|
||||
temperature=0.3,
|
||||
max_duration=datetime.timedelta(minutes=3),
|
||||
),
|
||||
one_shot_selected_tools=ToolsSchema(standard_tools=[secret_menu_function]),
|
||||
)
|
||||
|
||||
llm.register_function("get_secret_menu", get_secret_menu)
|
||||
|
||||
# Necessary to complete the function call lifecycle in Pipecat.
|
||||
context_aggregator = LLMContextAggregatorPair(LLMContext([]))
|
||||
|
||||
# Build the pipeline
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
context_aggregator.assistant(),
|
||||
transport.output(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
# Handle client connection event
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
# Handle client disconnection events
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Run the pipeline
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
161
examples/foundational/49a-thinking-anthropic.py
Normal file
161
examples/foundational/49a-thinking-anthropic.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
params=AnthropicLLMService.InputParams(
|
||||
thinking=AnthropicLLMService.ThinkingConfig(type="enabled", budget_tokens=2048)
|
||||
),
|
||||
)
|
||||
|
||||
transcript = TranscriptProcessor(process_thoughts=True)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
transcript.user(), # User transcripts
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # Assistant transcripts (including thoughts)
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Say hello briefly.",
|
||||
}
|
||||
)
|
||||
# Here are some example prompts conducive to demonstrating
|
||||
# thinking (picked from Google and Anthropic docs).
|
||||
# messages.append(
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "Analogize photosynthesis and growing up. Keep your answer concise.",
|
||||
# # "content": "Compare and contrast electric cars and hybrid cars."
|
||||
# # "content": "Are there an infinite number of prime numbers such that n mod 4 == 3?"
|
||||
# }
|
||||
# )
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
if isinstance(msg, (ThoughtTranscriptionMessage, TranscriptionMessage)):
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
role = "THOUGHT" if isinstance(msg, ThoughtTranscriptionMessage) else msg.role
|
||||
logger.info(f"Transcript: {timestamp}{role}: {msg.content}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
167
examples/foundational/49b-thinking-google.py
Normal file
167
examples/foundational/49b-thinking-google.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Each transport name maps to a zero-argument factory instead of a ready-made
# params object, so objects such as SileroVADAnalyzer are not instantiated up
# front. The factory is only called once the desired transport gets selected.
def _audio_params_factory(params_cls):
    """Return a factory that builds ``params_cls`` with the shared audio settings."""

    def build():
        return params_cls(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        )

    return build


transport_params = {
    "daily": _audio_params_factory(DailyParams),
    "twilio": _audio_params_factory(FastAPIWebsocketParams),
    "webrtc": _audio_params_factory(TransportParams),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble the Gemini "thinking" voice pipeline and run it.

    Builds the STT, TTS and LLM services, wires them into a pipeline together
    with a transcript processor that also handles thoughts, registers the
    transport and transcript event handlers, and finally awaits the pipeline
    runner.

    Args:
        transport: Transport providing the pipeline's input and output
            processors and the client connect/disconnect events.
        runner_args: Supplies ``pipeline_idle_timeout_secs`` and
            ``handle_sigint``.
    """
    logger.info("Starting bot")

    speech_to_text = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    text_to_speech = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    google_api_key = os.getenv("GOOGLE_API_KEY")
    thinking_config = GoogleLLMService.ThinkingConfig(
        # thinking_level="low", # Use this field instead of thinking_budget for Gemini 3 Pro. Defaults to "high".
        thinking_budget=-1,  # Dynamic thinking
        include_thoughts=True,
    )
    llm = GoogleLLMService(
        api_key=google_api_key,
        # model="gemini-3-pro-preview", # A more powerful reasoning model, but slower
        params=GoogleLLMService.InputParams(thinking=thinking_config),
    )

    transcript = TranscriptProcessor(process_thoughts=True)

    system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way."
    messages = [{"role": "system", "content": system_prompt}]

    context_aggregator = LLMContextAggregatorPair(LLMContext(messages))

    processors = [
        transport.input(),  # Transport user input
        speech_to_text,
        transcript.user(),  # User transcripts
        context_aggregator.user(),  # User responses
        llm,  # LLM
        text_to_speech,  # TTS
        transport.output(),  # Transport bot output
        transcript.assistant(),  # Assistant transcripts (including thoughts)
        context_aggregator.assistant(),  # Assistant spoken responses
    ]
    pipeline = Pipeline(processors)

    task_params = PipelineParams(enable_metrics=True, enable_usage_metrics=True)
    pipeline_task = PipelineTask(
        pipeline,
        params=task_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        opening_message = {"role": "user", "content": "Say hello briefly."}
        messages.append(opening_message)
        # Replace the above with one of these example prompts to demonstrate
        # thinking.
        # These examples come from Gemini and Anthropic docs.
        # messages.append(
        #     {
        #         "role": "user",
        #         "content": "Analogize photosynthesis and growing up. Keep your answer concise.",
        #         # "content": "Compare and contrast electric cars and hybrid cars."
        #         # "content": "Are there an infinite number of prime numbers such that n mod 4 == 3?"
        #     }
        # )
        await pipeline_task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await pipeline_task.cancel()

    # Log every transcript update, labelling model thoughts separately.
    @transcript.event_handler("on_transcript_update")
    async def on_transcript_update(processor, frame):
        for entry in frame.messages:
            is_thought = isinstance(entry, ThoughtTranscriptionMessage)
            if not (is_thought or isinstance(entry, TranscriptionMessage)):
                continue
            prefix = f"[{entry.timestamp}] " if entry.timestamp else ""
            speaker = "THOUGHT" if is_thought else entry.role
            logger.info(f"Transcript: {prefix}{speaker}: {entry.content}")

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(pipeline_task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Bot entry point; the signature is the one Pipecat Cloud expects.

    Creates the transport from ``runner_args`` and the module-level
    ``transport_params`` mapping, then hands it to ``run_bot``.

    Args:
        runner_args: Runner arguments forwarded unchanged to both
            ``create_transport`` and ``run_bot``.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
# When executed as a script, delegate to the Pipecat runner's ``main``. The
# import is kept inside the guard so importing this module does not load the
# runner.
if __name__ == "__main__":
    import pipecat.runner.run as pipecat_runner

    pipecat_runner.main()
|
||||
185
examples/foundational/49c-thinking-functions-anthropic.py
Normal file
185
examples/foundational/49c-thinking-functions-anthropic.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# NOTE: the signature and docstring are left untouched because this function
# is handed to ``llm.register_direct_function`` — presumably the tool schema is
# derived from them; confirm before rewording.
async def check_flight_status(params: FunctionCallParams, flight_number: str):
    """Check the status of a flight. Returns status (e.g., "on time", "delayed") and departure time.

    Args:
        flight_number (str): The flight number, e.g. "AA100".
    """
    # Demo stub: the same canned answer is reported for every flight number.
    canned_status = {"status": "delayed", "departure_time": "14:30"}
    await params.result_callback(canned_status)
|
||||
|
||||
|
||||
# NOTE: the signature and docstring are left untouched because this function
# is handed to ``llm.register_direct_function`` — presumably the tool schema is
# derived from them; confirm before rewording.
async def book_taxi(params: FunctionCallParams, time: str):
    """Book a taxi for a given time. Returns status (e.g., "done").

    Args:
        time (str): The time to book the taxi for, e.g. "15:00".
    """
    # Demo stub: nothing is actually booked; success is always reported.
    booking_result = {"status": "done"}
    await params.result_callback(booking_result)
|
||||
|
||||
|
||||
# Each transport name maps to a zero-argument factory instead of a ready-made
# params object, so objects such as SileroVADAnalyzer are not instantiated up
# front. The factory is only called once the desired transport gets selected.
def _audio_params_factory(params_cls):
    """Return a factory that builds ``params_cls`` with the shared audio settings."""

    def build():
        return params_cls(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        )

    return build


transport_params = {
    "daily": _audio_params_factory(DailyParams),
    "twilio": _audio_params_factory(FastAPIWebsocketParams),
    "webrtc": _audio_params_factory(TransportParams),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble the Anthropic "thinking + function calling" pipeline and run it.

    Builds the STT, TTS and LLM services, registers the two demo tool
    functions on the LLM, wires everything into a pipeline together with a
    transcript processor that also handles thoughts, registers the transport
    and transcript event handlers, and finally awaits the pipeline runner.

    Args:
        transport: Transport providing the pipeline's input and output
            processors and the client connect/disconnect events.
        runner_args: Supplies ``pipeline_idle_timeout_secs`` and
            ``handle_sigint``.
    """
    logger.info("Starting bot")

    speech_to_text = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    text_to_speech = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    thinking_config = AnthropicLLMService.ThinkingConfig(type="enabled", budget_tokens=2048)
    llm = AnthropicLLMService(
        api_key=anthropic_api_key,
        params=AnthropicLLMService.InputParams(thinking=thinking_config),
    )

    # The same two functions are registered on the LLM and advertised as tools.
    for tool_function in (check_flight_status, book_taxi):
        llm.register_direct_function(tool_function)

    tools = ToolsSchema(standard_tools=[check_flight_status, book_taxi])

    transcript = TranscriptProcessor(process_thoughts=True)

    system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way."
    messages = [{"role": "system", "content": system_prompt}]

    context_aggregator = LLMContextAggregatorPair(LLMContext(messages, tools))

    processors = [
        transport.input(),  # Transport user input
        speech_to_text,
        transcript.user(),  # User transcripts
        context_aggregator.user(),  # User responses
        llm,  # LLM
        text_to_speech,  # TTS
        transport.output(),  # Transport bot output
        transcript.assistant(),  # Assistant transcripts (including thoughts)
        context_aggregator.assistant(),  # Assistant spoken responses
    ]
    pipeline = Pipeline(processors)

    task_params = PipelineParams(enable_metrics=True, enable_usage_metrics=True)
    pipeline_task = PipelineTask(
        pipeline,
        params=task_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        opening_message = {"role": "user", "content": "Say hello briefly."}
        messages.append(opening_message)
        # Here is an example prompt conducive to demonstrating thinking and
        # function calling.
        # This example comes from Gemini docs.
        # messages.append(
        #     {
        #         "role": "user",
        #         "content": "Check the status of flight AA100 and, if it's delayed, book me a taxi 2 hours before its departure time.",
        #     }
        # )
        await pipeline_task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await pipeline_task.cancel()

    # Log every transcript update, labelling model thoughts separately.
    @transcript.event_handler("on_transcript_update")
    async def on_transcript_update(processor, frame):
        for entry in frame.messages:
            is_thought = isinstance(entry, ThoughtTranscriptionMessage)
            if not (is_thought or isinstance(entry, TranscriptionMessage)):
                continue
            prefix = f"[{entry.timestamp}] " if entry.timestamp else ""
            speaker = "THOUGHT" if is_thought else entry.role
            logger.info(f"Transcript: {prefix}{speaker}: {entry.content}")

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(pipeline_task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Bot entry point; the signature is the one Pipecat Cloud expects.

    Creates the transport from ``runner_args`` and the module-level
    ``transport_params`` mapping, then hands it to ``run_bot``.

    Args:
        runner_args: Runner arguments forwarded unchanged to both
            ``create_transport`` and ``run_bot``.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
# When executed as a script, delegate to the Pipecat runner's ``main``. The
# import is kept inside the guard so importing this module does not load the
# runner.
if __name__ == "__main__":
    import pipecat.runner.run as pipecat_runner

    pipecat_runner.main()
|
||||
190
examples/foundational/49d-thinking-functions-google.py
Normal file
190
examples/foundational/49d-thinking-functions-google.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# NOTE: the signature and docstring are left untouched because this function
# is handed to ``llm.register_direct_function`` — presumably the tool schema is
# derived from them; confirm before rewording.
async def check_flight_status(params: FunctionCallParams, flight_number: str):
    """Check the status of a flight. Returns status (e.g., "on time", "delayed") and departure time.

    Args:
        flight_number (str): The flight number, e.g. "AA100".
    """
    # Demo stub: the same canned answer is reported for every flight number.
    canned_status = {"status": "delayed", "departure_time": "14:30"}
    await params.result_callback(canned_status)
|
||||
|
||||
|
||||
# NOTE: the signature and docstring are left untouched because this function
# is handed to ``llm.register_direct_function`` — presumably the tool schema is
# derived from them; confirm before rewording.
async def book_taxi(params: FunctionCallParams, time: str):
    """Book a taxi for a given time. Returns status (e.g., "done").

    Args:
        time (str): The time to book the taxi for, e.g. "15:00".
    """
    # Demo stub: nothing is actually booked; success is always reported.
    booking_result = {"status": "done"}
    await params.result_callback(booking_result)
|
||||
|
||||
|
||||
# Each transport name maps to a zero-argument factory instead of a ready-made
# params object, so objects such as SileroVADAnalyzer are not instantiated up
# front. The factory is only called once the desired transport gets selected.
def _audio_params_factory(params_cls):
    """Return a factory that builds ``params_cls`` with the shared audio settings."""

    def build():
        return params_cls(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        )

    return build


transport_params = {
    "daily": _audio_params_factory(DailyParams),
    "twilio": _audio_params_factory(FastAPIWebsocketParams),
    "webrtc": _audio_params_factory(TransportParams),
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Assemble the Gemini "thinking + function calling" pipeline and run it.

    Builds the STT, TTS and LLM services, registers the two demo tool
    functions on the LLM, wires everything into a pipeline together with a
    transcript processor that also handles thoughts, registers the transport
    and transcript event handlers, and finally awaits the pipeline runner.

    Args:
        transport: Transport providing the pipeline's input and output
            processors and the client connect/disconnect events.
        runner_args: Supplies ``pipeline_idle_timeout_secs`` and
            ``handle_sigint``.
    """
    logger.info("Starting bot")

    speech_to_text = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    text_to_speech = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    google_api_key = os.getenv("GOOGLE_API_KEY")
    thinking_config = GoogleLLMService.ThinkingConfig(
        # thinking_level="low", # Use this field instead of thinking_budget for Gemini 3 Pro. Defaults to "high".
        thinking_budget=-1,  # Dynamic thinking
        include_thoughts=True,
    )
    llm = GoogleLLMService(
        api_key=google_api_key,
        # model="gemini-3-pro-preview", # A more powerful reasoning model, but slower
        params=GoogleLLMService.InputParams(thinking=thinking_config),
    )

    # The same two functions are registered on the LLM and advertised as tools.
    for tool_function in (check_flight_status, book_taxi):
        llm.register_direct_function(tool_function)

    tools = ToolsSchema(standard_tools=[check_flight_status, book_taxi])

    transcript = TranscriptProcessor(process_thoughts=True)

    system_prompt = "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way."
    messages = [{"role": "system", "content": system_prompt}]

    context_aggregator = LLMContextAggregatorPair(LLMContext(messages, tools))

    processors = [
        transport.input(),  # Transport user input
        speech_to_text,
        transcript.user(),  # User transcripts
        context_aggregator.user(),  # User responses
        llm,  # LLM
        text_to_speech,  # TTS
        transport.output(),  # Transport bot output
        transcript.assistant(),  # Assistant transcripts (including thoughts)
        context_aggregator.assistant(),  # Assistant spoken responses
    ]
    pipeline = Pipeline(processors)

    task_params = PipelineParams(enable_metrics=True, enable_usage_metrics=True)
    pipeline_task = PipelineTask(
        pipeline,
        params=task_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        opening_message = {"role": "user", "content": "Say hello briefly."}
        messages.append(opening_message)
        # Replace the above with one of these example prompts to demonstrate
        # thinking and function calling.
        # This example comes from Gemini docs.
        # messages.append(
        #     {
        #         "role": "user",
        #         "content": "Check the status of flight AA100 and, if it's delayed, book me a taxi 2 hours before its departure time.",
        #     }
        # )
        await pipeline_task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await pipeline_task.cancel()

    # Log every transcript update, labelling model thoughts separately.
    @transcript.event_handler("on_transcript_update")
    async def on_transcript_update(processor, frame):
        for entry in frame.messages:
            is_thought = isinstance(entry, ThoughtTranscriptionMessage)
            if not (is_thought or isinstance(entry, TranscriptionMessage)):
                continue
            prefix = f"[{entry.timestamp}] " if entry.timestamp else ""
            speaker = "THOUGHT" if is_thought else entry.role
            logger.info(f"Transcript: {prefix}{speaker}: {entry.content}")

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(pipeline_task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
    """Bot entry point; the signature is the one Pipecat Cloud expects.

    Creates the transport from ``runner_args`` and the module-level
    ``transport_params`` mapping, then hands it to ``run_bot``.

    Args:
        runner_args: Runner arguments forwarded unchanged to both
            ``create_transport`` and ``run_bot``.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
|
||||
|
||||
|
||||
# When executed as a script, delegate to the Pipecat runner's ``main``. The
# import is kept inside the guard so importing this module does not load the
# runner.
if __name__ == "__main__":
    import pipecat.runner.run as pipecat_runner

    pipecat_runner.main()
|
||||
@@ -62,7 +62,8 @@ fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.51.0,<2", "pipecat-ai[websockets-base]" ]
|
||||
gradium = [ "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
@@ -108,7 +109,7 @@ strands = [ "strands-agents>=1.9.1,<2" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
|
||||
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
|
||||
ultravox = [ "pipecat-ai[websockets-base]" ]
|
||||
webrtc = [ "aiortc>=1.13.0,<2", "opencv-python>=4.11.0.86,<5" ]
|
||||
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.122.0" ]
|
||||
websockets-base = [ "websockets>=13.1,<16.0" ]
|
||||
@@ -129,6 +130,7 @@ dev = [
|
||||
"setuptools~=78.1.1",
|
||||
"setuptools_scm~=8.3.1",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"towncrier~=25.8.0",
|
||||
]
|
||||
|
||||
docs = [
|
||||
@@ -159,7 +161,7 @@ where = ["src"]
|
||||
"src/pipecat/audio/dtmf/dtmf-star.wav",
|
||||
]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.1-cpu.onnx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
@@ -206,3 +208,45 @@ convention = "google"
|
||||
command_line = "--module pytest"
|
||||
source = ["src"]
|
||||
omit = ["*/tests/*"]
|
||||
|
||||
[tool.towncrier]
|
||||
package = "pipecat"
|
||||
package_dir = "src"
|
||||
filename = "CHANGELOG.md"
|
||||
directory = "changelog"
|
||||
start_string = "<!-- towncrier release notes start -->\n"
|
||||
template = "changelog/_template.md.j2"
|
||||
title_format = "## [{version}] - {project_date}"
|
||||
issue_format = "[#{issue}](https://github.com/pipecat-ai/pipecat/pull/{issue})"
|
||||
underlines = ["", "", ""]
|
||||
wrap = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "added"
|
||||
name = "Added"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "changed"
|
||||
name = "Changed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "deprecated"
|
||||
name = "Deprecated"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "removed"
|
||||
name = "Removed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "fixed"
|
||||
name = "Fixed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "security"
|
||||
name = "Security"
|
||||
showcontent = true
|
||||
|
||||
@@ -31,7 +31,13 @@ from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import EndTaskFrame, LLMRunFrame, OutputImageRawFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
EndTaskFrame,
|
||||
LLMRunFrame,
|
||||
OutputImageRawFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -50,6 +56,7 @@ SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
PIPELINE_IDLE_TIMEOUT_SECS = 60
|
||||
EVAL_TIMEOUT_SECS = 120
|
||||
EVAL_RESULT_TIMEOUT_SECS = 10
|
||||
|
||||
EvalPrompt = str | Tuple[str, ImageFile]
|
||||
|
||||
@@ -78,7 +85,7 @@ class EvalRunner:
|
||||
self._log_level = log_level
|
||||
self._total_success = 0
|
||||
self._tests: List[EvalResult] = []
|
||||
self._queue = asyncio.Queue()
|
||||
self._result_future: Optional[asyncio.Future[bool]] = None
|
||||
|
||||
# We to save runner files.
|
||||
name = name or f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
@@ -88,16 +95,16 @@ class EvalRunner:
|
||||
os.makedirs(self._logs_dir, exist_ok=True)
|
||||
os.makedirs(self._recordings_dir, exist_ok=True)
|
||||
|
||||
async def assert_eval(self, params: FunctionCallParams):
|
||||
async def function_assert_eval(self, params: FunctionCallParams):
|
||||
result = params.arguments["result"]
|
||||
reasoning = params.arguments["reasoning"]
|
||||
logger.debug(f"🧠 EVAL REASONING(result: {result}): {reasoning}")
|
||||
await self._queue.put(result)
|
||||
await params.result_callback(None)
|
||||
await params.llm.push_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await params.llm.push_frame(EndTaskFrame(reason=result), FrameDirection.UPSTREAM)
|
||||
|
||||
async def assert_eval_false(self):
|
||||
await self._queue.put(False)
|
||||
async def assert_eval(self, result: bool):
|
||||
if self._result_future:
|
||||
self._result_future.set_result(result)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
@@ -117,6 +124,9 @@ class EvalRunner:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Create a future to store the eval result.
|
||||
self._result_future = asyncio.get_running_loop().create_future()
|
||||
|
||||
try:
|
||||
tasks = [
|
||||
asyncio.create_task(run_example_pipeline(script_path, eval_config)),
|
||||
@@ -136,8 +146,10 @@ class EvalRunner:
|
||||
logger.error(f"ERROR: Unable to run {example_file}: {e}")
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
# Wait for the future to resolve.
|
||||
result = await asyncio.wait_for(self._result_future, timeout=EVAL_RESULT_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"ERROR: Timeout waiting for eval result.")
|
||||
result = False
|
||||
|
||||
if result:
|
||||
@@ -244,19 +256,25 @@ async def run_eval_pipeline(
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
llm.register_function("eval_function", eval_runner.assert_eval)
|
||||
llm.register_function("eval_function", eval_runner.function_assert_eval)
|
||||
|
||||
eval_function = FunctionSchema(
|
||||
name="eval_function",
|
||||
description="Called when the user answers a question.",
|
||||
description=(
|
||||
"Determines whether the user's response satisfies the evaluation "
|
||||
"criteria defined for the current prompt or interaction."
|
||||
),
|
||||
properties={
|
||||
"result": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the answer is correct or not",
|
||||
"description": "Whether the user's response meets the evaluation criteria.",
|
||||
},
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "Why the answer was considered correct or invalid",
|
||||
"description": (
|
||||
"A concise explanation of how the user's response did or did "
|
||||
"not satisfy the evaluation criteria."
|
||||
),
|
||||
},
|
||||
},
|
||||
required=["result", "reasoning"],
|
||||
@@ -278,9 +296,9 @@ async def run_eval_pipeline(
|
||||
"Ignore greetings, comments, non-answers, or requests for clarification."
|
||||
)
|
||||
if eval_config.eval_speaks_first:
|
||||
system_prompt = f"You are an evaluation agent, be extremly brief. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
|
||||
system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
|
||||
else:
|
||||
system_prompt = f"You are an evaluation agent, be extremly brief. First, ask one question: {example_prompt}. {common_system_prompt}"
|
||||
system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. First, ask one question: {example_prompt}. {common_system_prompt}"
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -346,9 +364,12 @@ async def run_eval_pipeline(
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
@task.event_handler("on_idle_timeout")
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
await eval_runner.assert_eval_false()
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
if isinstance(frame, EndFrame):
|
||||
await eval_runner.assert_eval(frame.reason)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await eval_runner.assert_eval(False)
|
||||
|
||||
# TODO(aleix): We should handle SIGINT and SIGTERM so we can cancel both the
|
||||
# eval and the example.
|
||||
|
||||
@@ -30,13 +30,13 @@ EVAL_SIMPLE_MATH = EvalConfig(
|
||||
)
|
||||
|
||||
EVAL_WEATHER = EvalConfig(
|
||||
prompt="What's the weather in San Francisco (in farhenheit or celsius)?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees (in farhenheit or celsius).",
|
||||
prompt="What's the weather in San Francisco? Temperature should be in fahrenheits.",
|
||||
eval="The user talks about the weather in San Francisco, including the degrees.",
|
||||
)
|
||||
|
||||
EVAL_ONLINE_SEARCH = EvalConfig(
|
||||
prompt="What's the date right now in London?",
|
||||
eval=f"The user says today is {datetime.now(timezone.utc).strftime('%B %d, %Y')} in London.",
|
||||
prompt="What's the current date in UTC?",
|
||||
eval=f"Current date in UTC is {datetime.now(timezone.utc).strftime('%A, %B %d, %Y')}.",
|
||||
)
|
||||
|
||||
EVAL_SWITCH_LANGUAGE = EvalConfig(
|
||||
@@ -64,16 +64,21 @@ def EVAL_VISION_IMAGE(*, eval_speaks_first: bool = False):
|
||||
|
||||
EVAL_VOICEMAIL = EvalConfig(
|
||||
prompt="Please leave a message.",
|
||||
eval="The user leaves a voicemail message.",
|
||||
eval="The user provides a reasonable voicemail message.",
|
||||
eval_speaks_first=True,
|
||||
)
|
||||
|
||||
EVAL_CONVERSATION = EvalConfig(
|
||||
prompt="Hello, this is Mark.",
|
||||
eval="The user acknowledges the greeting.",
|
||||
eval="The user provides any reasonable conversational response to the greeting.",
|
||||
eval_speaks_first=True,
|
||||
)
|
||||
|
||||
EVAL_FLIGHT_STATUS = EvalConfig(
|
||||
prompt="Check the status of flight AA100.",
|
||||
eval="The user says something about the status of flight AA100, such as whether it's on time or delayed.",
|
||||
)
|
||||
|
||||
|
||||
TESTS_07 = [
|
||||
# 07 series
|
||||
@@ -81,6 +86,7 @@ TESTS_07 = [
|
||||
("07-interruptible-cartesia-http.py", EVAL_SIMPLE_MATH),
|
||||
("07a-interruptible-speechmatics.py", EVAL_SIMPLE_MATH),
|
||||
("07aa-interruptible-soniox.py", EVAL_SIMPLE_MATH),
|
||||
("07ab-interruptible-inworld.py", EVAL_SIMPLE_MATH),
|
||||
("07ab-interruptible-inworld-http.py", EVAL_SIMPLE_MATH),
|
||||
("07ac-interruptible-asyncai.py", EVAL_SIMPLE_MATH),
|
||||
("07ac-interruptible-asyncai-http.py", EVAL_SIMPLE_MATH),
|
||||
@@ -116,8 +122,6 @@ TESTS_07 = [
|
||||
# ("07i-interruptible-xtts.py", EVAL_SIMPLE_MATH),
|
||||
# Needs a Krisp license.
|
||||
# ("07p-interruptible-krisp.py", EVAL_SIMPLE_MATH),
|
||||
# Needs GPU resources.
|
||||
# ("07u-interruptible-ultravox.py", EVAL_SIMPLE_MATH),
|
||||
]
|
||||
|
||||
TESTS_12 = [
|
||||
@@ -204,6 +208,13 @@ TESTS_44 = [
|
||||
("44-voicemail-detection.py", EVAL_CONVERSATION),
|
||||
]
|
||||
|
||||
TESTS_49 = [
|
||||
("49a-thinking-anthropic.py", EVAL_SIMPLE_MATH),
|
||||
("49b-thinking-google.py", EVAL_SIMPLE_MATH),
|
||||
("49c-thinking-functions-anthropic.py", EVAL_FLIGHT_STATUS),
|
||||
("49d-thinking-functions-google.py", EVAL_FLIGHT_STATUS),
|
||||
]
|
||||
|
||||
TESTS = [
|
||||
*TESTS_07,
|
||||
*TESTS_12,
|
||||
@@ -216,6 +227,7 @@ TESTS = [
|
||||
*TESTS_40,
|
||||
*TESTS_43,
|
||||
*TESTS_44,
|
||||
*TESTS_49,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -5,14 +5,20 @@
|
||||
#
|
||||
|
||||
import sys
|
||||
from importlib.metadata import version
|
||||
from importlib.metadata import version as lib_version
|
||||
|
||||
from loguru import logger
|
||||
|
||||
__version__ = version("pipecat-ai")
|
||||
__version__ = lib_version("pipecat-ai")
|
||||
|
||||
logger.info(f"ᓚᘏᗢ Pipecat {__version__} (Python {sys.version}) ᓚᘏᗢ")
|
||||
|
||||
|
||||
def version() -> str:
|
||||
"""Returns the Pipecat version."""
|
||||
return __version__
|
||||
|
||||
|
||||
# We replace `asyncio.wait_for()` for `wait_for2.wait_for()` for Python < 3.12.
|
||||
#
|
||||
# In Python 3.12, `asyncio.wait_for()` is implemented in terms of
|
||||
|
||||
@@ -94,6 +94,8 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image":
|
||||
item["source"]["data"] = "..."
|
||||
if item["type"] == "thinking" and item.get("signature"):
|
||||
item["signature"] = "..."
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
@@ -165,9 +167,44 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
|
||||
def _from_universal_context_message(self, message: LLMContextMessage) -> MessageParam:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return copy.deepcopy(message.message)
|
||||
return self._from_anthropic_specific_message(message)
|
||||
return self._from_standard_message(message)
|
||||
|
||||
def _from_anthropic_specific_message(self, message: LLMSpecificMessage) -> MessageParam:
    """Translate an Anthropic-specific context message into Anthropic format.

    Two shapes are handled: a special "thought" dict carrying both text and
    a signature, which becomes a standalone assistant message holding a
    single "thinking" block; and anything else, which is assumed to already
    be in Anthropic format and is returned as a deep copy.

    Args:
        message: Anthropic-specific message.

    Returns:
        The message in Anthropic format.
    """
    payload = message.message

    # Thought messages become standalone "assistant" messages here; they are
    # merged into the assistant response messages later, before the context
    # is sent to Anthropic for the next turn.
    if isinstance(payload, dict) and payload.get("type") == "thought":
        thought_text = payload.get("text")
        thought_signature = payload.get("signature")
        # Both pieces are required; otherwise fall through to the copy below.
        if thought_text and thought_signature:
            thinking_block = {
                "type": "thinking",
                "thinking": thought_text,
                "signature": thought_signature,
            }
            return {"role": "assistant", "content": [thinking_block]}

    # Fall back to assuming that the message is already in Anthropic format
    return copy.deepcopy(payload)
|
||||
|
||||
def _from_standard_message(self, message: LLMStandardMessage) -> MessageParam:
|
||||
"""Convert standard universal context message to Anthropic format.
|
||||
|
||||
@@ -246,11 +283,14 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:"):
|
||||
# Extract MIME type from data URL (format: "data:image/jpeg;base64,...")
|
||||
url = item["image_url"]["url"]
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
item["type"] = "image"
|
||||
item["source"] = {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": item["image_url"]["url"].split(",")[1],
|
||||
"media_type": mime_type,
|
||||
"data": url.split(",")[1],
|
||||
}
|
||||
del item["image_url"]
|
||||
elif item["image_url"]["url"].startswith("http"):
|
||||
|
||||
@@ -257,14 +257,15 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:"):
|
||||
# Extract format from data URL (format: "data:image/jpeg;base64,...")
|
||||
url = item["image_url"]["url"]
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
# Bedrock expects format like "jpeg", "png" etc., not "image/jpeg"
|
||||
image_format = mime_type.split("/")[1]
|
||||
new_item = {
|
||||
"image": {
|
||||
"format": "jpeg",
|
||||
"source": {
|
||||
"bytes": base64.b64decode(
|
||||
item["image_url"]["url"].split(",")[1]
|
||||
)
|
||||
},
|
||||
"format": image_format,
|
||||
"source": {"bytes": base64.b64decode(url.split(",")[1])},
|
||||
}
|
||||
}
|
||||
new_content.append(new_item)
|
||||
|
||||
@@ -151,6 +151,8 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
for part in obj["parts"]:
|
||||
if "inline_data" in part:
|
||||
part["inline_data"]["data"] = "..."
|
||||
if "thought_signature" in part:
|
||||
part["thought_signature"] = "..."
|
||||
except Exception as e:
|
||||
logger.debug(f"Error: {e}")
|
||||
messages_for_logging.append(obj)
|
||||
@@ -209,16 +211,37 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
system_instruction = None
|
||||
messages = []
|
||||
tool_call_id_to_name_mapping = {}
|
||||
thought_signature_dicts = []
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
# Process each message, converting to Google format as needed
|
||||
for message in universal_context_messages:
|
||||
result = self._from_universal_context_message(
|
||||
# We have a Google-specific message; this may either be a
|
||||
# thought-signature-containing message that we need to handle in a
|
||||
# special way, or a message already in Google format that we can
|
||||
# use directly
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
if (
|
||||
isinstance(message.message, dict)
|
||||
and message.message.get("type") == "thought_signature"
|
||||
):
|
||||
thought_signature_dicts.append(message.message)
|
||||
continue
|
||||
|
||||
# Fall back to assuming that the message is already in Google
|
||||
# format
|
||||
messages.append(message.message)
|
||||
continue
|
||||
|
||||
# We have a standard universal context message; convert it to
|
||||
# Google format
|
||||
result = self._from_standard_message(
|
||||
message,
|
||||
params=self.MessageConversionParams(
|
||||
already_have_system_instruction=bool(system_instruction),
|
||||
tool_call_id_to_name_mapping=tool_call_id_to_name_mapping,
|
||||
),
|
||||
)
|
||||
|
||||
# Each result is either a Content or a system instruction
|
||||
if result.content:
|
||||
messages.append(result.content)
|
||||
@@ -229,6 +252,9 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
if result.tool_call_id_to_name_mapping:
|
||||
tool_call_id_to_name_mapping.update(result.tool_call_id_to_name_mapping)
|
||||
|
||||
# Apply thought signatures to the corresponding messages
|
||||
self._apply_thought_signatures_to_messages(thought_signature_dicts, messages)
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
len(msg.parts) == 1
|
||||
@@ -247,13 +273,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
|
||||
|
||||
def _from_universal_context_message(
|
||||
self, message: LLMContextMessage, *, params: MessageConversionParams
|
||||
) -> MessageConversionResult:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return self.MessageConversionResult(content=message.message)
|
||||
return self._from_standard_message(message, params=params)
|
||||
|
||||
def _from_standard_message(
|
||||
self, message: LLMStandardMessage, *, params: MessageConversionParams
|
||||
) -> MessageConversionResult:
|
||||
@@ -380,11 +399,14 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
if c["type"] == "text":
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url" and c["image_url"]["url"].startswith("data:"):
|
||||
# Extract MIME type from data URL (format: "data:image/jpeg;base64,...")
|
||||
url = c["image_url"]["url"]
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
mime_type=mime_type,
|
||||
data=base64.b64decode(url.split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -410,3 +432,139 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
content=Content(role=role, parts=parts),
|
||||
tool_call_id_to_name_mapping=tool_call_id_to_name_mapping,
|
||||
)
|
||||
|
||||
def _apply_thought_signatures_to_messages(
    self, thought_signature_dicts: List[dict], messages: List[Content]
) -> None:
    """Apply thought signatures to corresponding assistant messages.

    See GoogleLLMService for more details about thought signatures.

    Signatures are matched against assistant ("model") messages in order:
    once a signature has been applied to a message, the search for the next
    signature resumes at the following message. Only the last part of each
    message is considered a candidate. Matching parts are mutated in place.

    Args:
        thought_signature_dicts: A list of dicts containing:
            - "signature": a thought signature
            - "bookmark": a bookmark to identify the message part to apply the signature to.
              The bookmark may contain one of:
                - "function_call" (a function call ID string)
                - "text" (a text string)
                - "inline_data" (a Blob)
            The list of thought signature dicts is in order. Entries missing
            either key are skipped.
        messages: List of messages to apply the thought signatures to.
    """
    if not thought_signature_dicts:
        return

    # For debugging, print out thought signatures and their bookmarks
    logger.debug(f"Thought signatures to apply: {len(thought_signature_dicts)}")
    for ts in thought_signature_dicts:
        # An entry without a bookmark is skipped by the matching loop below,
        # so the debug output must tolerate it too instead of raising
        # AttributeError on None.
        bookmark = ts.get("bookmark") or {}
        if bookmark.get("function_call"):
            logger.trace(f" - To function call: {bookmark['function_call']}")
        elif bookmark.get("text"):
            text = bookmark["text"]
            log_display_text = f"{text[:50]}..." if len(text) > 50 else text
            logger.trace(f" - To text: {log_display_text}")
        elif bookmark.get("inline_data"):
            logger.trace(" - To inline data")

    # Get all assistant messages
    assistant_messages = [
        message
        for message in messages
        if isinstance(message, Content) and message.role == "model"
    ]

    # Apply thought signatures to the corresponding assistant messages.
    # Thought signatures are already in message order.
    thought_signatures_applied = 0
    message_start_index = 0  # Track where to start searching for the next matching message.
    for thought_signature_dict in thought_signature_dicts:
        signature = thought_signature_dict.get("signature")
        bookmark = thought_signature_dict.get("bookmark")
        if not signature or not bookmark:
            continue

        # Search through remaining assistant messages for a match
        for i in range(message_start_index, len(assistant_messages)):
            message = assistant_messages[i]
            if not message.parts:
                continue

            # We're assuming that the thought signature always applies to the last part
            last_part = message.parts[-1]

            # If the bookmark matches the part...
            if self._thought_signature_bookmark_matches_part(bookmark, last_part):
                # Apply the thought signature
                last_part.thought_signature = signature
                thought_signatures_applied += 1

                # Update the start index and stop searching for a match
                message_start_index = i + 1
                break

    # For debugging, print out how many thought signatures were applied
    logger.debug(f"Applied {thought_signatures_applied} thought signatures.")
|
||||
|
||||
def _thought_signature_bookmark_matches_part(self, bookmark: dict, part: Part) -> bool:
    """Check whether a thought signature bookmark identifies the given part.

    The bookmark is dispatched on the first truthy key among
    "function_call", "text" and "inline_data", in that order. A bookmark
    with none of them is logged as unknown and never matches.

    Args:
        bookmark: The bookmark dict from a thought signature entry.
        part: The message part to test.

    Returns:
        True if the bookmark matches the part, False otherwise.
    """
    function_call_id = bookmark.get("function_call")
    if function_call_id:
        return self._thought_signature_function_call_bookmark_matches_part(
            function_call_id, part
        )

    text = bookmark.get("text")
    if text:
        return self._thought_signature_text_bookmark_matches_part(text, part)

    blob = bookmark.get("inline_data")
    if blob:
        return self._thought_signature_inline_data_bookmark_matches_part(blob, part)

    logger.warning(f"Unknown thought signature bookmark type: {bookmark}")
    return False
|
||||
|
||||
def _thought_signature_function_call_bookmark_matches_part(
    self, bookmark_function_call_id: str, part: Part
) -> bool:
    """Check whether the part is the function call named by the bookmark.

    Args:
        bookmark_function_call_id: Function call ID stored in the bookmark.
        part: The message part to test.

    Returns:
        True if the part has a function call whose ID equals the bookmark's.
    """
    call = getattr(part, "function_call", None)
    if not call or call.id != bookmark_function_call_id:
        return False

    logger.trace(f"Thought signature function call match: {bookmark_function_call_id}")
    return True
|
||||
|
||||
def _thought_signature_text_bookmark_matches_part(self, bookmark_text: str, part: Part) -> bool:
    """Check whether the part's text corresponds to the bookmark text.

    Both texts are whitespace-normalized before comparison. They match when
    one is a prefix of the other, which also covers equality.

    Args:
        bookmark_text: Text stored in the bookmark.
        part: The message part to test.

    Returns:
        True if the part has text that matches the bookmark text.
    """
    raw_part_text = getattr(part, "text", None)
    if not raw_part_text:
        return False

    # Normalize whitespace for comparison
    normalized_bookmark = " ".join(bookmark_text.split())
    normalized_part = " ".join(raw_part_text.split())

    # Whitespace-only text is truthy but normalizes to "". Every string
    # starts with "", so without this guard such text would match any
    # counterpart.
    if not normalized_bookmark or not normalized_part:
        return False

    # Check that either:
    # - the part text is the same as the bookmark text
    # - the part text is a prefix of the bookmark text (in case the part
    #   text was truncated due to interruption)
    # - the bookmark text is a prefix of the part text (in case the bookmark
    #   represents just the first chunk of multi-chunk text)
    if normalized_bookmark.startswith(normalized_part) or normalized_part.startswith(
        normalized_bookmark
    ):
        log_display_text = f"{part.text[:50]}..." if len(part.text) > 50 else part.text
        logger.trace(f"Thought signature text match: {log_display_text}")
        return True

    return False
|
||||
|
||||
def _thought_signature_inline_data_bookmark_matches_part(
    self, bookmark_inline_data: Blob, part: Part
) -> bool:
    """Check whether the part's inline data corresponds to the bookmark's.

    Only the data lengths are compared. That should be good enough because
    thought signatures are already matched in strict message order, and
    comparing the actual data is expensive.

    Args:
        bookmark_inline_data: Inline data stored in the bookmark.
        part: The message part to test.

    Returns:
        True if the part has inline data of the same length as the bookmark's.
    """
    part_blob = getattr(part, "inline_data", None)
    if not part_blob:
        return False

    part_data = part_blob.data
    bookmark_data = bookmark_inline_data.data
    # NOTE(review): a Blob's data appears to be optional; treat a missing
    # payload as "no match" rather than letting len(None) raise TypeError.
    if part_data is None or bookmark_data is None:
        return False

    if len(part_data) == len(bookmark_data):
        logger.trace("Thought signature inline data match")
        return True

    return False
|
||||
|
||||
@@ -28,7 +28,6 @@ from pipecat.metrics.metrics import MetricsData, SmartTurnMetricsData
|
||||
STOP_SECS = 3
|
||||
PRE_SPEECH_MS = 0
|
||||
MAX_DURATION_SECONDS = 8 # Max allowed segment duration
|
||||
USE_ONLY_LAST_VAD_SEGMENT = True
|
||||
|
||||
|
||||
class SmartTurnParams(BaseTurnParams):
|
||||
@@ -43,8 +42,6 @@ class SmartTurnParams(BaseTurnParams):
|
||||
stop_secs: float = STOP_SECS
|
||||
pre_speech_ms: float = PRE_SPEECH_MS
|
||||
max_duration_secs: float = MAX_DURATION_SECONDS
|
||||
# not exposing this for now yet until the model can handle it.
|
||||
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
|
||||
|
||||
|
||||
class SmartTurnTimeoutException(Exception):
|
||||
@@ -160,7 +157,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
state, result = await loop.run_in_executor(
|
||||
self._executor, self._process_speech_segment, self._audio_buffer
|
||||
)
|
||||
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
self._clear(state)
|
||||
logger.debug(f"End of Turn result: {state}")
|
||||
return state, result
|
||||
|
||||
Binary file not shown.
@@ -14,6 +14,7 @@ Note: To learn more about the smart-turn model, visit:
|
||||
- https://github.com/pipecat-ai/smart-turn
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -26,6 +27,10 @@ class FalSmartTurnAnalyzer(HttpSmartTurnAnalyzer):
|
||||
|
||||
Extends HttpSmartTurnAnalyzer to provide integration with Fal.ai's
|
||||
smart turn detection API endpoint with proper authentication.
|
||||
|
||||
.. deprecated:: 0.98.0
|
||||
FalSmartTurnAnalyzer is deprecated and will be removed in a future version.
|
||||
Use LocalSmartTurnAnalyzerV3 instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -48,3 +53,12 @@ class FalSmartTurnAnalyzer(HttpSmartTurnAnalyzer):
|
||||
if api_key:
|
||||
headers = {"Authorization": f"Key {api_key}"}
|
||||
super().__init__(url=url, aiohttp_session=aiohttp_session, headers=headers, **kwargs)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"FalSmartTurnAnalyzer is deprecated and will be removed in a future version. "
|
||||
"Use LocalSmartTurnAnalyzerV3 instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ This module provides a smart turn analyzer that uses PyTorch models for
|
||||
local end-of-turn detection without requiring network connectivity.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
@@ -34,6 +35,10 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn):
|
||||
Provides end-of-turn detection using locally-stored PyTorch models,
|
||||
enabling offline operation without network dependencies. Uses
|
||||
Wav2Vec2-BERT architecture for audio sequence classification.
|
||||
|
||||
.. deprecated:: 0.98.0
|
||||
LocalSmartTurnAnalyzer is deprecated and will be removed in a future version.
|
||||
Use LocalSmartTurnAnalyzerV3 instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *, smart_turn_model_path: str, **kwargs):
|
||||
@@ -46,6 +51,15 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"LocalSmartTurnAnalyzer is deprecated and will be removed in a future version. "
|
||||
"Use LocalSmartTurnAnalyzerV3 instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Define the path to the pretrained model on Hugging Face
|
||||
smart_turn_model_path = "pipecat-ai/smart-turn"
|
||||
|
||||
@@ -42,17 +42,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
Args:
|
||||
smart_turn_model_path: Path to the ONNX model file. If this is not
|
||||
set, the bundled smart-turn-v3.0 model will be used.
|
||||
set, the bundled smart-turn-v3.1-cpu model will be used.
|
||||
cpu_count: The number of CPUs to use for inference. Defaults to 1.
|
||||
**kwargs: Additional arguments passed to BaseSmartTurn.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
logger.debug("Loading Local Smart Turn v3 model...")
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.0.onnx"
|
||||
model_name = "smart-turn-v3.1-cpu.onnx"
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
@@ -70,6 +68,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
)
|
||||
|
||||
logger.debug(f"Loading Local Smart Turn v3.x model from {smart_turn_model_path}...")
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
so.inter_op_num_threads = 1
|
||||
@@ -79,7 +79,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3")
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
@@ -252,7 +252,8 @@ class ClassificationProcessor(FrameProcessor):
|
||||
self._voicemail_notifier = voicemail_notifier
|
||||
self._voicemail_response_delay = voicemail_response_delay
|
||||
|
||||
# Register the voicemail detected event
|
||||
# Register the conversation and voicemail detected events
|
||||
self._register_event_handler("on_conversation_detected")
|
||||
self._register_event_handler("on_voicemail_detected")
|
||||
|
||||
# Aggregation state for collecting complete LLM responses
|
||||
@@ -350,6 +351,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
logger.info(f"{self}: CONVERSATION detected")
|
||||
await self._gate_notifier.notify() # Close the classifier gate
|
||||
await self._conversation_notifier.notify() # Release buffered TTS frames
|
||||
await self._call_event_handler("on_conversation_detected")
|
||||
|
||||
elif "VOICEMAIL" in response:
|
||||
# Voicemail detected - trigger voicemail handling
|
||||
@@ -539,6 +541,9 @@ class VoicemailDetector(ParallelPipeline):
|
||||
custom_prompt = "Your custom classification logic here. " + VoicemailDetector.CLASSIFIER_RESPONSE_INSTRUCTION
|
||||
|
||||
Events:
|
||||
on_conversation_detected: Triggered when a human conversation is detected. The
|
||||
event handler receives one argument: the ClassificationProcessor instance
|
||||
which can be used to push frames.
|
||||
on_voicemail_detected: Triggered when voicemail is detected after the configured
|
||||
delay. The event handler receives one argument: the ClassificationProcessor
|
||||
instance which can be used to push frames.
|
||||
@@ -701,7 +706,7 @@ VOICEMAIL SYSTEM (respond "VOICEMAIL"):
|
||||
event_name: The name of the event to handle.
|
||||
handler: The function to call when the event occurs.
|
||||
"""
|
||||
if event_name == "on_voicemail_detected":
|
||||
if event_name in ("on_conversation_detected", "on_voicemail_detected"):
|
||||
self._classification_processor.add_event_handler(event_name, handler)
|
||||
else:
|
||||
super().add_event_handler(event_name, handler)
|
||||
|
||||
@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
@@ -186,6 +186,20 @@ class ControlFrame(Frame):
|
||||
#
|
||||
|
||||
|
||||
@dataclass
|
||||
class UninterruptibleFrame:
|
||||
"""A marker for data or control frames that must not be interrupted.
|
||||
|
||||
Frames with this mixin are still ordered normally, but unlike other frames,
|
||||
they are preserved during interruptions: they remain in internal queues and
|
||||
any task processing them will not be cancelled. This ensures the frame is
|
||||
always delivered and processed to completion.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioRawFrame:
|
||||
"""A frame containing a chunk of raw audio.
|
||||
@@ -213,7 +227,7 @@ class ImageRawFrame:
|
||||
Parameters:
|
||||
image: Raw image bytes.
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
format: Image format (e.g., 'JPEG', 'PNG').
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
"""
|
||||
|
||||
image: bytes
|
||||
@@ -386,6 +400,13 @@ class AggregatedTextFrame(TextFrame):
|
||||
aggregated_by: AggregationType | str
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionTextFrame(LLMTextFrame):
|
||||
"""Text frame generated by vision services."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSTextFrame(AggregatedTextFrame):
|
||||
"""Text frame generated by Text-to-Speech services."""
|
||||
@@ -498,6 +519,15 @@ class TranscriptionMessage:
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThoughtTranscriptionMessage:
|
||||
"""An LLM thought message in a conversation transcript."""
|
||||
|
||||
role: Literal["assistant"] = field(default="assistant", init=False)
|
||||
content: str
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionUpdateFrame(DataFrame):
|
||||
"""Frame containing new messages added to conversation transcript.
|
||||
@@ -542,7 +572,7 @@ class TranscriptionUpdateFrame(DataFrame):
|
||||
messages: List of new transcript messages that were added.
|
||||
"""
|
||||
|
||||
messages: List[TranscriptionMessage]
|
||||
messages: List[TranscriptionMessage | ThoughtTranscriptionMessage]
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
@@ -563,6 +593,75 @@ class LLMContextFrame(Frame):
|
||||
context: "LLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMThoughtStartFrame(ControlFrame):
|
||||
"""Frame indicating the start of an LLM thought.
|
||||
|
||||
Parameters:
|
||||
append_to_context: Whether the thought should be appended to the LLM context.
|
||||
If it is appended, the `llm` field is required, since it will be
|
||||
appended as an `LLMSpecificMessage`.
|
||||
llm: Optional identifier of the LLM provider for LLM-specific handling.
|
||||
Only required if `append_to_context` is True, as the thought is
|
||||
appended to context as an `LLMSpecificMessage`.
|
||||
"""
|
||||
|
||||
append_to_context: bool = False
|
||||
llm: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.append_to_context and self.llm is None:
|
||||
raise ValueError("When append_to_context is True, llm must be set")
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
return (
|
||||
f"{self.name}(pts: {pts}, append_to_context: {self.append_to_context}, llm: {self.llm})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMThoughtTextFrame(DataFrame):
|
||||
"""Frame containing the text (or text chunk) of an LLM thought.
|
||||
|
||||
Note that despite this containing text, it is a DataFrame and not a
|
||||
TextFrame, to avoid most typical text processing, such as TTS.
|
||||
|
||||
Parameters:
|
||||
text: The text (or text chunk) of the thought.
|
||||
"""
|
||||
|
||||
text: str
|
||||
includes_inter_frame_spaces: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Assume that thought text chunks include all necessary spaces
|
||||
self.includes_inter_frame_spaces = True
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
return f"{self.name}(pts: {pts}, thought text: {self.text})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMThoughtEndFrame(ControlFrame):
|
||||
"""Frame indicating the end of an LLM thought.
|
||||
|
||||
Parameters:
|
||||
signature: Optional signature associated with the thought.
|
||||
This is used by Anthropic, which includes a signature at the end of
|
||||
each thought.
|
||||
"""
|
||||
|
||||
signature: Any = None
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
return f"{self.name}(pts: {pts}, signature: {self.signature})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesFrame(DataFrame):
|
||||
"""Frame containing LLM messages for chat completion.
|
||||
@@ -696,6 +795,44 @@ class LLMConfigureOutputFrame(DataFrame):
|
||||
skip_tts: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultProperties:
|
||||
"""Properties for configuring function call result behavior.
|
||||
|
||||
Parameters:
|
||||
run_llm: Whether to run the LLM after receiving this result.
|
||||
on_context_updated: Callback to execute when context is updated.
|
||||
"""
|
||||
|
||||
run_llm: Optional[bool] = None
|
||||
on_context_updated: Optional[Callable[[], Awaitable[None]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultFrame(DataFrame, UninterruptibleFrame):
|
||||
"""Frame containing the result of an LLM function call.
|
||||
|
||||
This is an uninterruptible frame because once a result is generated we
|
||||
always want to update the context.
|
||||
|
||||
Parameters:
|
||||
function_name: Name of the function that was executed.
|
||||
tool_call_id: Unique identifier for the function call.
|
||||
arguments: Arguments that were passed to the function.
|
||||
result: The result returned by the function.
|
||||
run_llm: Whether to run the LLM after this result.
|
||||
properties: Additional properties for result handling.
|
||||
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Any
|
||||
result: Any
|
||||
run_llm: Optional[bool] = None
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSSpeakFrame(DataFrame):
|
||||
"""Frame containing text that should be spoken by TTS.
|
||||
@@ -817,7 +954,7 @@ class CancelFrame(SystemFrame):
|
||||
reason: Optional reason for pushing a cancel frame.
|
||||
"""
|
||||
|
||||
reason: Optional[str] = None
|
||||
reason: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
@@ -1089,23 +1226,6 @@ class FunctionCallsStartedFrame(SystemFrame):
|
||||
function_calls: Sequence[FunctionCallFromLLM]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallInProgressFrame(SystemFrame):
|
||||
"""Frame signaling that a function call is currently executing.
|
||||
|
||||
Parameters:
|
||||
function_name: Name of the function being executed.
|
||||
tool_call_id: Unique identifier for this function call.
|
||||
arguments: Arguments passed to the function.
|
||||
cancel_on_interruption: Whether to cancel this call if interrupted.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Any
|
||||
cancel_on_interruption: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallCancelFrame(SystemFrame):
|
||||
"""Frame signaling that a function call has been cancelled.
|
||||
@@ -1119,40 +1239,6 @@ class FunctionCallCancelFrame(SystemFrame):
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultProperties:
|
||||
"""Properties for configuring function call result behavior.
|
||||
|
||||
Parameters:
|
||||
run_llm: Whether to run the LLM after receiving this result.
|
||||
on_context_updated: Callback to execute when context is updated.
|
||||
"""
|
||||
|
||||
run_llm: Optional[bool] = None
|
||||
on_context_updated: Optional[Callable[[], Awaitable[None]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultFrame(SystemFrame):
|
||||
"""Frame containing the result of an LLM function call.
|
||||
|
||||
Parameters:
|
||||
function_name: Name of the function that was executed.
|
||||
tool_call_id: Unique identifier for the function call.
|
||||
arguments: Arguments that were passed to the function.
|
||||
result: The result returned by the function.
|
||||
run_llm: Whether to run the LLM after this result.
|
||||
properties: Additional properties for result handling.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Any
|
||||
result: Any
|
||||
run_llm: Optional[bool] = None
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTMuteFrame(SystemFrame):
|
||||
"""Frame to mute/unmute the Speech-to-Text service.
|
||||
@@ -1387,6 +1473,23 @@ class UserImageRawFrame(InputImageRawFrame):
|
||||
return f"{self.name}(pts: {pts}, user: {self.user_id}, source: {self.transport_source}, size: {self.size}, format: {self.format}, text: {self.text}, append_to_context: {self.append_to_context})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantImageRawFrame(OutputImageRawFrame):
|
||||
"""Frame containing an image generated by the assistant.
|
||||
|
||||
Contains both the raw frame for display (superclass functionality) as well
|
||||
as the original image, which can get used directly in LLM contexts.
|
||||
|
||||
Parameters:
|
||||
original_data: The original image data, which can get used directly in
|
||||
an LLM context message without further encoding.
|
||||
original_mime_type: The MIME type of the original image data.
|
||||
"""
|
||||
|
||||
original_data: Optional[bytes] = None
|
||||
original_mime_type: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputDTMFFrame(DTMFFrame, SystemFrame):
|
||||
"""DTMF keypress input frame from transport."""
|
||||
@@ -1454,7 +1557,7 @@ class EndTaskFrame(TaskFrame):
|
||||
reason: Optional reason for pushing an end frame.
|
||||
"""
|
||||
|
||||
reason: Optional[str] = None
|
||||
reason: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
@@ -1472,7 +1575,7 @@ class CancelTaskFrame(TaskFrame):
|
||||
reason: Optional reason for pushing a cancel frame.
|
||||
"""
|
||||
|
||||
reason: Optional[str] = None
|
||||
reason: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
@@ -1551,7 +1654,7 @@ class EndFrame(ControlFrame):
|
||||
reason: Optional reason for pushing an end frame.
|
||||
"""
|
||||
|
||||
reason: Optional[str] = None
|
||||
reason: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(reason: {self.reason})"
|
||||
@@ -1650,6 +1753,45 @@ class LLMFullResponseEndFrame(ControlFrame):
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame):
|
||||
"""Frame signaling that a function call is currently executing.
|
||||
|
||||
This is an uninterruptible frame because we always want to update the
|
||||
context.
|
||||
|
||||
Parameters:
|
||||
function_name: Name of the function being executed.
|
||||
tool_call_id: Unique identifier for this function call.
|
||||
arguments: Arguments passed to the function.
|
||||
cancel_on_interruption: Whether to cancel this call if interrupted.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Any
|
||||
cancel_on_interruption: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionFullResponseStartFrame(LLMFullResponseStartFrame):
|
||||
"""Frame indicating the beginning of a vision model response.
|
||||
|
||||
Used to indicate the beginning of a vision model response. Followed by one
|
||||
or more VisionTextFrames and a final VisionFullResponseEndFrame.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionFullResponseEndFrame(LLMFullResponseEndFrame):
|
||||
"""Frame indicating the end of a Vision model response."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSStartedFrame(ControlFrame):
|
||||
"""Frame indicating the beginning of a TTS response.
|
||||
|
||||
@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -36,7 +36,7 @@ class UserBotLatencyLogObserver(BaseObserver):
|
||||
to calculate response latencies.
|
||||
"""
|
||||
super().__init__()
|
||||
self._processed_frames = set()
|
||||
self._user_bot_latency_processed_frames = set()
|
||||
self._user_stopped_time = 0
|
||||
self._latencies = []
|
||||
|
||||
@@ -51,14 +51,14 @@ class UserBotLatencyLogObserver(BaseObserver):
|
||||
return
|
||||
|
||||
# Skip already processed frames
|
||||
if data.frame.id in self._processed_frames:
|
||||
if data.frame.id in self._user_bot_latency_processed_frames:
|
||||
return
|
||||
|
||||
self._processed_frames.add(data.frame.id)
|
||||
self._user_bot_latency_processed_frames.add(data.frame.id)
|
||||
|
||||
if isinstance(data.frame, UserStartedSpeakingFrame):
|
||||
if isinstance(data.frame, VADUserStartedSpeakingFrame):
|
||||
self._user_stopped_time = 0
|
||||
elif isinstance(data.frame, UserStoppedSpeakingFrame):
|
||||
elif isinstance(data.frame, VADUserStoppedSpeakingFrame):
|
||||
self._user_stopped_time = time.time()
|
||||
elif isinstance(data.frame, (EndFrame, CancelFrame)):
|
||||
self._log_summary()
|
||||
|
||||
@@ -150,21 +150,29 @@ class LLMContext:
|
||||
|
||||
Args:
|
||||
role: The role of this message (defaults to "user").
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
format: Image format (e.g., 'RGB', 'RGBA', or, if already encoded,
|
||||
the MIME type like 'image/jpeg').
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
# Format is a mime type: image is already encoded
|
||||
image_already_encoded = format.startswith("image/")
|
||||
|
||||
def encode_image():
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
if image_already_encoded:
|
||||
bytes = image
|
||||
else:
|
||||
# Encode to JPEG
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
bytes = buffer.getvalue()
|
||||
encoded_image = base64.b64encode(bytes).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
encoded_image = await asyncio.to_thread(encode_image)
|
||||
|
||||
url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
url = f"data:{format if image_already_encoded else 'image/jpeg'};base64,{encoded_image}"
|
||||
|
||||
return LLMContext.create_image_url_message(role=role, url=url, text=text)
|
||||
|
||||
@@ -179,13 +187,12 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
content = [{"type": "text", "text": text}]
|
||||
|
||||
async def encode_audio():
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
@@ -195,7 +202,7 @@ class LLMContext:
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_audio
|
||||
|
||||
encoded_audio = await asyncio.to_thread(encode_audio)
|
||||
@@ -334,18 +341,26 @@ class LLMContext:
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
async def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
self,
|
||||
*,
|
||||
format: str,
|
||||
size: tuple[int, int],
|
||||
image: bytes,
|
||||
text: Optional[str] = None,
|
||||
role: str = "user",
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
format: Image format (e.g., 'RGB', 'RGBA', or, if already encoded,
|
||||
the MIME type like 'image/jpeg').
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
role: The role of this message (defaults to "user").
|
||||
"""
|
||||
message = await LLMContext.create_image_message(
|
||||
format=format, size=size, image=image, text=text
|
||||
role=role, format=format, size=size, image=image, text=text
|
||||
)
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -47,6 +48,9 @@ from pipecat.frames.frames import (
|
||||
LLMRunFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
@@ -592,6 +596,10 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
|
||||
self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
self._thought_aggregation_enabled = False
|
||||
self._thought_llm: str = ""
|
||||
self._thought_aggregation: List[TextPartForConcatenation] = []
|
||||
|
||||
@property
|
||||
def has_function_calls_in_progress(self) -> bool:
|
||||
"""Check if there are any function calls currently in progress.
|
||||
@@ -601,6 +609,17 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""
|
||||
return bool(self._function_calls_in_progress)
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the aggregation state."""
|
||||
await super().reset()
|
||||
await self._reset_thought_aggregation() # Just to be safe
|
||||
|
||||
async def _reset_thought_aggregation(self):
|
||||
"""Reset the thought aggregation state."""
|
||||
self._thought_aggregation_enabled = False
|
||||
self._thought_llm = ""
|
||||
self._thought_aggregation = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for assistant response aggregation and function call management.
|
||||
|
||||
@@ -619,6 +638,12 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMThoughtStartFrame):
|
||||
await self._handle_thought_start(frame)
|
||||
elif isinstance(frame, LLMThoughtTextFrame):
|
||||
await self._handle_thought_text(frame)
|
||||
elif isinstance(frame, LLMThoughtEndFrame):
|
||||
await self._handle_thought_end(frame)
|
||||
elif isinstance(frame, LLMRunFrame):
|
||||
await self._handle_llm_run(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
@@ -639,6 +664,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_function_call_cancel(frame)
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
await self._handle_user_image_frame(frame)
|
||||
elif isinstance(frame, AssistantImageRawFrame):
|
||||
await self._handle_assistant_image_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self.push_aggregation()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -803,6 +830,24 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_aggregation()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_assistant_image_frame(self, frame: AssistantImageRawFrame):
|
||||
logger.debug(f"{self} Appending AssistantImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
if frame.original_data and frame.original_mime_type:
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.original_mime_type,
|
||||
size=frame.size, # Technically doesn't matter, since already encoded
|
||||
image=frame.original_data,
|
||||
role="assistant",
|
||||
)
|
||||
else:
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started += 1
|
||||
|
||||
@@ -824,6 +869,47 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_thought_start(self, frame: LLMThoughtStartFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
await self._reset_thought_aggregation()
|
||||
self._thought_aggregation_enabled = frame.append_to_context
|
||||
self._thought_llm = frame.llm
|
||||
|
||||
async def _handle_thought_text(self, frame: LLMThoughtTextFrame):
|
||||
if not self._started or not self._thought_aggregation_enabled:
|
||||
return
|
||||
|
||||
# Make sure we really have text (spaces count, too!)
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
self._thought_aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_thought_end(self, frame: LLMThoughtEndFrame):
|
||||
if not self._started or not self._thought_aggregation_enabled:
|
||||
return
|
||||
|
||||
thought = concatenate_aggregated_text(self._thought_aggregation)
|
||||
llm = self._thought_llm
|
||||
await self._reset_thought_aggregation()
|
||||
|
||||
self._context.add_message(
|
||||
LLMSpecificMessage(
|
||||
llm=llm,
|
||||
message={
|
||||
"type": "thought",
|
||||
"text": thought,
|
||||
"signature": frame.signature,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ management, and frame flow control mechanisms.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence, Tuple, Type
|
||||
@@ -32,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
UninterruptibleFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver, FrameProcessed, FramePushed
|
||||
@@ -210,6 +212,7 @@ class FrameProcessor(BaseObject):
|
||||
# The input task that handles all types of frames. It processes system
|
||||
# frames right away and queues non-system frames for later processing.
|
||||
self.__should_block_system_frames = False
|
||||
self.__input_queue = FrameProcessorQueue()
|
||||
self.__input_event: Optional[asyncio.Event] = None
|
||||
self.__input_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
@@ -219,8 +222,10 @@ class FrameProcessor(BaseObject):
|
||||
# called. To resume processing frames we need to call
|
||||
# `resume_processing_frames()` which will wake up the event.
|
||||
self.__should_block_frames = False
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_current_frame: Optional[Frame] = None
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
@@ -677,7 +682,17 @@ class FrameProcessor(BaseObject):
|
||||
if not error.processor:
|
||||
error.processor = self
|
||||
await self._call_event_handler("on_error", error)
|
||||
logger.error(f"{error.processor} error: {error.error}")
|
||||
|
||||
if error.exception:
|
||||
tb = traceback.extract_tb(error.exception.__traceback__)
|
||||
last = tb[-1]
|
||||
error_message = (
|
||||
f"{error.processor} exception ({last.filename}:{last.lineno}): {error.error}"
|
||||
)
|
||||
else:
|
||||
error_message = f"{error.processor} error: {error.error}"
|
||||
|
||||
logger.error(error_message)
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -794,8 +809,12 @@ class FrameProcessor(BaseObject):
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
elif isinstance(self.__process_current_frame, UninterruptibleFrame):
|
||||
# We don't want to cancel UninterruptibleFrame, so we simply
|
||||
# cleanup the queue.
|
||||
self.__reset_process_queue()
|
||||
else:
|
||||
# Cancel and re-create the process task including the queue.
|
||||
# Cancel and re-create the process task.
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
@@ -861,7 +880,6 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
if not self.__input_frame_task:
|
||||
self.__input_event = asyncio.Event()
|
||||
self.__input_queue = FrameProcessorQueue()
|
||||
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
|
||||
|
||||
async def __cancel_input_task(self):
|
||||
@@ -879,9 +897,7 @@ class FrameProcessor(BaseObject):
|
||||
return
|
||||
|
||||
if not self.__process_frame_task:
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = asyncio.Event()
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__reset_process_task()
|
||||
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
|
||||
|
||||
def __reset_process_task(self):
|
||||
@@ -891,10 +907,26 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = asyncio.Event()
|
||||
self.__reset_process_queue()
|
||||
|
||||
def __reset_process_queue(self):
|
||||
"""Reset non-system frame processing queue."""
|
||||
# Create a new queue to insert UninterruptibleFrame frames.
|
||||
new_queue = asyncio.Queue()
|
||||
|
||||
# Process current queue and keep UninterruptibleFrame frames.
|
||||
while not self.__process_queue.empty():
|
||||
self.__process_queue.get_nowait()
|
||||
item = self.__process_queue.get_nowait()
|
||||
if isinstance(item, UninterruptibleFrame):
|
||||
new_queue.put_nowait(item)
|
||||
self.__process_queue.task_done()
|
||||
|
||||
# Put back UninterruptibleFrame frames into our process queue.
|
||||
while not new_queue.empty():
|
||||
item = new_queue.get_nowait()
|
||||
self.__process_queue.put_nowait(item)
|
||||
new_queue.task_done()
|
||||
|
||||
async def __cancel_process_task(self):
|
||||
"""Cancel the non-system frame processing task."""
|
||||
if self.__process_frame_task:
|
||||
@@ -948,8 +980,12 @@ class FrameProcessor(BaseObject):
|
||||
async def __process_frame_task_handler(self):
|
||||
"""Handle non-system frames from the process queue."""
|
||||
while True:
|
||||
self.__process_current_frame = None
|
||||
|
||||
(frame, direction, callback) = await self.__process_queue.get()
|
||||
|
||||
self.__process_current_frame = frame
|
||||
|
||||
if self.__should_block_frames and self.__process_event:
|
||||
logger.trace(f"{self}: frame processing paused")
|
||||
await self.__process_event.wait()
|
||||
|
||||
@@ -31,6 +31,7 @@ from typing import (
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat import version as pipecat_version
|
||||
from pipecat.audio.utils import calculate_audio_volume
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
@@ -85,7 +86,7 @@ from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
|
||||
RTVI_PROTOCOL_VERSION = "1.0.0"
|
||||
RTVI_PROTOCOL_VERSION = "1.1.0"
|
||||
|
||||
RTVI_MESSAGE_LABEL = "rtvi-ai"
|
||||
RTVIMessageLiteral = Literal["rtvi-ai"]
|
||||
@@ -935,8 +936,8 @@ class RTVIObserverParams:
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
@@ -1417,15 +1418,20 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._client_ready = True
|
||||
await self._call_event_handler("on_client_ready")
|
||||
|
||||
async def set_bot_ready(self):
|
||||
"""Mark the bot as ready and send the bot-ready message."""
|
||||
async def set_bot_ready(self, about: Mapping[str, Any] = None):
|
||||
"""Mark the bot as ready and send the bot-ready message.
|
||||
|
||||
Args:
|
||||
about: Optional information about the bot to include in the ready message.
|
||||
If left as None, the Pipecat library and version will be used.
|
||||
"""
|
||||
self._bot_ready = True
|
||||
# Only call the (deprecated) _update_config method if the we're using a
|
||||
# config (which is deprecated). Otherwise we'd always print an
|
||||
# unnecessary deprecation warning.
|
||||
if self._config.config:
|
||||
await self._update_config(self._config, False)
|
||||
await self._send_bot_ready()
|
||||
await self._send_bot_ready(about=about)
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
@@ -1873,14 +1879,21 @@ class RTVIProcessor(FrameProcessor):
|
||||
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
|
||||
await self.push_transport_message(message)
|
||||
|
||||
async def _send_bot_ready(self):
|
||||
"""Send the bot-ready message to the client."""
|
||||
async def _send_bot_ready(self, about: Mapping[str, Any] = None):
|
||||
"""Send the bot-ready message to the client.
|
||||
|
||||
Args:
|
||||
about: Optional information about the bot to include in the ready message.
|
||||
If left as None, the pipecat library and version will be used.
|
||||
"""
|
||||
config = None
|
||||
if self._client_version and self._client_version[0] < 1:
|
||||
config = self._config.config
|
||||
if not about:
|
||||
about = {"library": "pipecat-ai", "library_version": f"{pipecat_version()}"}
|
||||
message = RTVIBotReady(
|
||||
id=self._client_ready_id,
|
||||
data=RTVIBotReadyData(version=RTVI_PROTOCOL_VERSION, config=config),
|
||||
data=RTVIBotReadyData(version=RTVI_PROTOCOL_VERSION, about=about, config=config),
|
||||
)
|
||||
await self.push_transport_message(message)
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
ThoughtTranscriptionMessage,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
@@ -81,92 +85,98 @@ class UserTranscriptProcessor(BaseTranscriptProcessor):
|
||||
|
||||
|
||||
class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes assistant TTS text frames into timestamped conversation messages.
|
||||
"""Processes assistant TTS text frames and LLM thought frames into timestamped messages.
|
||||
|
||||
This processor aggregates TTS text frames into complete utterances and emits them as
|
||||
transcript messages. Utterances are completed when:
|
||||
This processor aggregates both TTS text frames and LLM thought frames into
|
||||
complete utterances and thoughts, emitting them as transcript messages.
|
||||
|
||||
An assistant utterance is completed when:
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame)
|
||||
- The pipeline ends (EndFrame, CancelFrame)
|
||||
|
||||
A thought is completed when:
|
||||
- The thought ends (LLMThoughtEndFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame, CancelFrame)
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, *, process_thoughts: bool = False, **kwargs):
|
||||
"""Initialize processor with aggregation state.
|
||||
|
||||
Args:
|
||||
process_thoughts: Whether to process LLM thought frames. Defaults to False.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._current_text_parts: List[TextPartForConcatenation] = []
|
||||
self._aggregation_start_time: Optional[str] = None
|
||||
|
||||
async def _emit_aggregated_text(self):
|
||||
self._process_thoughts = process_thoughts
|
||||
self._current_assistant_text_parts: List[TextPartForConcatenation] = []
|
||||
self._assistant_text_start_time: Optional[str] = None
|
||||
|
||||
self._current_thought_parts: List[TextPartForConcatenation] = []
|
||||
self._thought_start_time: Optional[str] = None
|
||||
self._thought_active = False
|
||||
|
||||
async def _emit_aggregated_assistant_text(self):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method uses a heuristic to automatically detect whether text fragments
|
||||
contain embedded spacing (spaces at the beginning or end of fragments) or not,
|
||||
and applies the appropriate joining strategy. It handles fragments from different
|
||||
TTS services with different formatting patterns.
|
||||
|
||||
Examples:
|
||||
Fragments with embedded spacing (concatenated)::
|
||||
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: [" there"] # Leading space
|
||||
TTSTextFrame: ["!"]
|
||||
TTSTextFrame: [" How"] # Leading space
|
||||
TTSTextFrame: ["'s"]
|
||||
TTSTextFrame: [" it"] # Leading space
|
||||
|
||||
Result: "Hello there! How's it"
|
||||
|
||||
Fragments with trailing spaces (concatenated)::
|
||||
|
||||
TTSTextFrame: ["Hel"]
|
||||
TTSTextFrame: ["lo "] # Trailing space
|
||||
TTSTextFrame: ["to "] # Trailing space
|
||||
TTSTextFrame: ["you"]
|
||||
|
||||
Result: "Hello to you"
|
||||
|
||||
Word-by-word fragments without spacing (joined with spaces)::
|
||||
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: ["there"]
|
||||
TTSTextFrame: ["how"]
|
||||
TTSTextFrame: ["are"]
|
||||
TTSTextFrame: ["you"]
|
||||
|
||||
Result: "Hello there how are you"
|
||||
This method aggregates text fragments that may arrive in multiple
|
||||
TTSTextFrame instances and emits them as a single TranscriptionMessage.
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = concatenate_aggregated_text(self._current_text_parts)
|
||||
if self._current_assistant_text_parts and self._assistant_text_start_time:
|
||||
content = concatenate_aggregated_text(self._current_assistant_text_parts)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
timestamp=self._aggregation_start_time,
|
||||
timestamp=self._assistant_text_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.trace("No content to emit after stripping whitespace")
|
||||
|
||||
# Reset aggregation state
|
||||
self._current_text_parts = []
|
||||
self._aggregation_start_time = None
|
||||
self._current_assistant_text_parts = []
|
||||
self._assistant_text_start_time = None
|
||||
|
||||
async def _emit_aggregated_thought(self):
|
||||
"""Aggregates and emits thought text fragments as a thought transcript message.
|
||||
|
||||
This method aggregates thought fragments that may arrive in multiple
|
||||
LLMThoughtTextFrame instances and emits them as a single ThoughtTranscriptionMessage.
|
||||
"""
|
||||
if self._current_thought_parts and self._thought_start_time:
|
||||
content = concatenate_aggregated_text(self._current_thought_parts)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated thought message: {content}")
|
||||
message = ThoughtTranscriptionMessage(
|
||||
content=content,
|
||||
timestamp=self._thought_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.trace("No thought content to emit after stripping whitespace")
|
||||
|
||||
# Reset aggregation state
|
||||
self._current_thought_parts = []
|
||||
self._thought_start_time = None
|
||||
self._thought_active = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames into assistant conversation messages.
|
||||
"""Process frames into assistant conversation messages and thought messages.
|
||||
|
||||
Handles different frame types:
|
||||
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- LLMThoughtStartFrame: Begins aggregating a new thought
|
||||
- LLMThoughtTextFrame: Aggregates text for current thought
|
||||
- LLMThoughtEndFrame: Completes current thought
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- InterruptionFrame: Completes current utterance due to interruption
|
||||
- EndFrame: Completes current utterance at pipeline end
|
||||
- CancelFrame: Completes current utterance due to cancellation
|
||||
- InterruptionFrame: Completes current utterance and thought due to interruption
|
||||
- EndFrame: Completes current utterance and thought at pipeline end
|
||||
- CancelFrame: Completes current utterance and thought due to cancellation
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
@@ -178,14 +188,40 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
# Push frame first otherwise our emitted transcription update frame
|
||||
# might get cleaned up.
|
||||
await self.push_frame(frame, direction)
|
||||
# Emit accumulated text with interruptions
|
||||
await self._emit_aggregated_text()
|
||||
# Emit accumulated text and thought with interruptions
|
||||
await self._emit_aggregated_assistant_text()
|
||||
if self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
elif isinstance(frame, LLMThoughtStartFrame):
|
||||
# Start a new thought
|
||||
if self._process_thoughts:
|
||||
self._thought_active = True
|
||||
self._thought_start_time = time_now_iso8601()
|
||||
self._current_thought_parts = []
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMThoughtTextFrame):
|
||||
# Aggregate thought text if we have an active thought
|
||||
if self._process_thoughts and self._thought_active:
|
||||
self._current_thought_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMThoughtEndFrame):
|
||||
# Emit accumulated thought when thought ends
|
||||
if self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
# Start timestamp on first text part
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
if not self._assistant_text_start_time:
|
||||
self._assistant_text_start_time = time_now_iso8601()
|
||||
|
||||
self._current_text_parts.append(
|
||||
self._current_assistant_text_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
@@ -195,7 +231,10 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, (BotStoppedSpeakingFrame, EndFrame)):
|
||||
# Emit accumulated text when bot finishes speaking or pipeline ends.
|
||||
await self._emit_aggregated_text()
|
||||
await self._emit_aggregated_assistant_text()
|
||||
# Emit accumulated thought at pipeline end if still active
|
||||
if isinstance(frame, EndFrame) and self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
@@ -206,7 +245,8 @@ class TranscriptProcessor:
|
||||
"""Factory for creating and managing transcript processors.
|
||||
|
||||
Provides unified access to user and assistant transcript processors
|
||||
with shared event handling.
|
||||
with shared event handling. The assistant processor handles both TTS text
|
||||
and LLM thought frames.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -221,7 +261,7 @@ class TranscriptProcessor:
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
transcript.assistant_tts(), # Assistant transcripts
|
||||
transcript.assistant(), # Assistant transcripts (including thoughts)
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
@@ -231,8 +271,14 @@ class TranscriptProcessor:
|
||||
print(f"New messages: {frame.messages}")
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize factory."""
|
||||
def __init__(self, *, process_thoughts: bool = False):
|
||||
"""Initialize factory.
|
||||
|
||||
Args:
|
||||
process_thoughts: Whether the assistant processor should handle LLM thought
|
||||
frames. Defaults to False.
|
||||
"""
|
||||
self._process_thoughts = process_thoughts
|
||||
self._user_processor = None
|
||||
self._assistant_processor = None
|
||||
self._event_handlers = {}
|
||||
@@ -267,7 +313,9 @@ class TranscriptProcessor:
|
||||
The assistant transcript processor instance.
|
||||
"""
|
||||
if self._assistant_processor is None:
|
||||
self._assistant_processor = AssistantTranscriptProcessor(**kwargs)
|
||||
self._assistant_processor = AssistantTranscriptProcessor(
|
||||
process_thoughts=self._process_thoughts, **kwargs
|
||||
)
|
||||
# Apply any registered event handlers
|
||||
for event_name, handler in self._event_handlers.items():
|
||||
|
||||
|
||||
@@ -171,6 +171,7 @@ def _create_server_app(
|
||||
esp32_mode: bool = False,
|
||||
whatsapp_enabled: bool = False,
|
||||
folder: Optional[str] = None,
|
||||
dialin_enabled: bool = False,
|
||||
):
|
||||
"""Create FastAPI app with transport-specific routes."""
|
||||
app = FastAPI()
|
||||
@@ -189,7 +190,7 @@ def _create_server_app(
|
||||
if whatsapp_enabled:
|
||||
_setup_whatsapp_routes(app)
|
||||
elif transport_type == "daily":
|
||||
_setup_daily_routes(app)
|
||||
_setup_daily_routes(app, dialin_enabled=dialin_enabled)
|
||||
elif transport_type in TELEPHONY_TRANSPORTS:
|
||||
_setup_telephony_routes(app, transport_type=transport_type, proxy=proxy)
|
||||
else:
|
||||
@@ -533,8 +534,13 @@ def _setup_whatsapp_routes(app: FastAPI):
|
||||
_add_lifespan_to_app(app, whatsapp_lifespan)
|
||||
|
||||
|
||||
def _setup_daily_routes(app: FastAPI):
|
||||
"""Set up Daily-specific routes."""
|
||||
def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
"""Set up Daily-specific routes.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
dialin_enabled: If True, adds /daily-dialin-webhook endpoint for PSTN dial-in handling
|
||||
"""
|
||||
|
||||
@app.get("/")
|
||||
async def create_room_and_start_agent():
|
||||
@@ -639,6 +645,116 @@ def _setup_daily_routes(app: FastAPI):
|
||||
|
||||
return result
|
||||
|
||||
if dialin_enabled:
|
||||
|
||||
@app.post("/daily-dialin-webhook")
|
||||
async def handle_dialin_webhook(request: Request):
|
||||
"""Handle incoming Daily PSTN dial-in webhook.
|
||||
|
||||
This endpoint mimics Pipecat Cloud's dial-in webhook handler.
|
||||
It receives Daily webhook data, creates a SIP-enabled room, and starts the bot.
|
||||
|
||||
Expected webhook payload::
|
||||
|
||||
{
|
||||
"From": "+15551234567",
|
||||
"To": "+15559876543",
|
||||
"callId": "uuid-call-id",
|
||||
"callDomain": "uuid-call-domain",
|
||||
"sipHeaders": {...} // optional
|
||||
}
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"dailyRoom": "https://...",
|
||||
"dailyToken": "...",
|
||||
"sessionId": "uuid"
|
||||
}
|
||||
"""
|
||||
logger.debug("Received Daily dial-in webhook")
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
logger.debug(f"Webhook data: {data}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse webhook data: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
# Handle webhook verification test (sent by Daily when configuring webhook)
|
||||
if data.get("test") or data.get("Test"):
|
||||
logger.debug("Webhook verification test received")
|
||||
return {"status": "OK"}
|
||||
|
||||
# Validate required fields
|
||||
if not all(key in data for key in ["From", "To", "callId", "callDomain"]):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing required fields: From, To, callId, callDomain",
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pipecat.runner.daily import configure
|
||||
from pipecat.runner.types import DailyDialinRequest, DialinSettings
|
||||
|
||||
# Create Daily room with SIP capabilities
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
room_config = await configure(session, sip_caller_phone=data.get("From"))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Daily room: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create Daily room: {str(e)}"
|
||||
)
|
||||
|
||||
# Get Daily API URL from environment, fallback to production
|
||||
daily_api_url = os.getenv("DAILY_API_URL", "https://api.daily.co/v1")
|
||||
|
||||
# Get Daily API key from environment
|
||||
daily_api_key = os.getenv("DAILY_API_KEY")
|
||||
if not daily_api_key:
|
||||
logger.error("DAILY_API_KEY not found in environment")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="DAILY_API_KEY not configured on server"
|
||||
)
|
||||
|
||||
# Prepare dial-in settings matching Pipecat Cloud structure
|
||||
dialin_settings = DialinSettings(
|
||||
call_id=data.get("callId"),
|
||||
call_domain=data.get("callDomain"),
|
||||
To=data.get("To"),
|
||||
From=data.get("From"),
|
||||
sip_headers=data.get("sipHeaders"),
|
||||
)
|
||||
|
||||
# Create request body matching Pipecat Cloud payload
|
||||
request_body = DailyDialinRequest(
|
||||
dialin_settings=dialin_settings,
|
||||
daily_api_key=daily_api_key,
|
||||
daily_api_url=daily_api_url,
|
||||
)
|
||||
|
||||
# Start bot with dial-in context
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = DailyRunnerArguments(
|
||||
room_url=room_config.room_url,
|
||||
token=room_config.token,
|
||||
body=request_body.model_dump(),
|
||||
)
|
||||
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
|
||||
# Generate session ID
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Return response matching Pipecat Cloud format
|
||||
return {
|
||||
"dailyRoom": room_config.room_url,
|
||||
"dailyToken": room_config.token,
|
||||
"sessionId": session_id,
|
||||
}
|
||||
|
||||
|
||||
def _setup_telephony_routes(app: FastAPI, *, transport_type: str, proxy: str):
|
||||
"""Set up telephony-specific routes."""
|
||||
@@ -813,6 +929,12 @@ def main():
|
||||
default=False,
|
||||
help="Ensure requried WhatsApp environment variables are present",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dialin",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Daily PSTN dial-in webhook handling (requires Daily transport)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -832,6 +954,11 @@ def main():
|
||||
logger.error("For ESP32, you need to specify `--host IP` so we can do SDP munging.")
|
||||
return
|
||||
|
||||
# Validate dial-in requirements
|
||||
if args.dialin and args.transport != "daily":
|
||||
logger.error("--dialin flag only works with Daily transport (-t daily)")
|
||||
return
|
||||
|
||||
# Log level
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level="TRACE" if args.verbose else "DEBUG")
|
||||
@@ -860,7 +987,13 @@ def main():
|
||||
elif args.transport == "daily":
|
||||
print()
|
||||
print(f"🚀 Bot ready!")
|
||||
print(f" → Open http://{args.host}:{args.port} in your browser to start a session")
|
||||
if args.dialin:
|
||||
print(
|
||||
f" → Daily dial-in webhook: http://{args.host}:{args.port}/daily-dialin-webhook"
|
||||
)
|
||||
print(f" → Configure this URL in your Daily phone number settings")
|
||||
else:
|
||||
print(f" → Open http://{args.host}:{args.port} in your browser to start a session")
|
||||
print()
|
||||
|
||||
RUNNER_DOWNLOADS_FOLDER = args.folder
|
||||
@@ -875,6 +1008,7 @@ def main():
|
||||
esp32_mode=args.esp32,
|
||||
whatsapp_enabled=args.whatsapp,
|
||||
folder=args.folder,
|
||||
dialin_enabled=args.dialin,
|
||||
)
|
||||
|
||||
# Run the server
|
||||
|
||||
@@ -11,9 +11,48 @@ information to bot functions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DialinSettings(BaseModel):
|
||||
"""Dial-in settings from the Daily webhook.
|
||||
|
||||
This model matches the structure sent by Pipecat Cloud and Daily.co webhooks
|
||||
for incoming PSTN/SIP calls.
|
||||
|
||||
Parameters:
|
||||
call_id: Unique identifier for the call (UUID representing sessionId in SIP Network)
|
||||
call_domain: Daily domain for the call (UUID representing Daily Domain on SIP Network)
|
||||
To: The dialed phone number (optional)
|
||||
From: The caller's phone number (optional)
|
||||
sip_headers: Optional SIP headers from the call
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
call_domain: str
|
||||
To: Optional[str] = None
|
||||
From: Optional[str] = None
|
||||
sip_headers: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class DailyDialinRequest(BaseModel):
|
||||
"""Request data for Daily PSTN dial-in requests.
|
||||
|
||||
This is the structure passed in runner_args.body for dial-in calls.
|
||||
It matches the payload structure from Pipecat Cloud's dial-in webhook handler.
|
||||
|
||||
Parameters:
|
||||
dialin_settings: Dial-in configuration including call_id, call_domain, To, From
|
||||
daily_api_key: Daily API key for pinlessCallUpdate (required for dial-in)
|
||||
daily_api_url: Daily API URL (staging or production)
|
||||
"""
|
||||
|
||||
dialin_settings: DialinSettings
|
||||
daily_api_key: str
|
||||
daily_api_url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -17,7 +17,7 @@ import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
@@ -40,6 +40,9 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
@@ -110,6 +113,24 @@ class AnthropicLLMService(LLMService):
|
||||
# Overriding the default adapter to use the Anthropic one.
|
||||
adapter_class = AnthropicLLMAdapter
|
||||
|
||||
class ThinkingConfig(BaseModel):
|
||||
"""Configuration for extended thinking.
|
||||
|
||||
Parameters:
|
||||
type: Type of thinking mode (currently only "enabled" or "disabled").
|
||||
budget_tokens: Maximum number of tokens for thinking.
|
||||
With today's models, the minimum is 1024.
|
||||
Only allowed if type is "enabled".
|
||||
"""
|
||||
|
||||
# Why `| str` here? To not break compatibility in case Anthropic adds
|
||||
# more types in the future.
|
||||
type: Literal["enabled", "disabled"] | str
|
||||
|
||||
# Why not enforce minimum of 1024 here? To not break compatibility in
|
||||
# case Anthropic changes this requirement in the future.
|
||||
budget_tokens: int
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Anthropic model inference.
|
||||
|
||||
@@ -124,6 +145,10 @@ class AnthropicLLMService(LLMService):
|
||||
temperature: Sampling temperature between 0.0 and 1.0.
|
||||
top_k: Top-k sampling parameter.
|
||||
top_p: Top-p sampling parameter between 0.0 and 1.0.
|
||||
thinking: Extended thinking configuration.
|
||||
Enabling extended thinking causes the model to spend more time "thinking" before responding.
|
||||
It also causes this service to emit LLMThought*Frames during response generation.
|
||||
Extended thinking is disabled by default.
|
||||
extra: Additional parameters to pass to the API.
|
||||
"""
|
||||
|
||||
@@ -133,6 +158,9 @@ class AnthropicLLMService(LLMService):
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
thinking: Optional["AnthropicLLMService.ThinkingConfig"] = Field(
|
||||
default_factory=lambda: NOT_GIVEN
|
||||
)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def model_post_init(self, __context):
|
||||
@@ -191,6 +219,7 @@ class AnthropicLLMService(LLMService):
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"thinking": params.thinking,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
@@ -238,28 +267,43 @@ class AnthropicLLMService(LLMService):
|
||||
"""
|
||||
messages = []
|
||||
system = NOT_GIVEN
|
||||
tools = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(
|
||||
invocation_params = adapter.get_llm_invocation_params(
|
||||
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system"]
|
||||
messages = invocation_params["messages"]
|
||||
system = invocation_params["system"]
|
||||
tools = invocation_params["tools"]
|
||||
else:
|
||||
context = AnthropicLLMContext.upgrade_to_anthropic(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", NOT_GIVEN)
|
||||
tools = context.tools or []
|
||||
|
||||
# Build params using the same method as streaming completions
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": False,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"messages": messages,
|
||||
"system": system,
|
||||
"tools": tools,
|
||||
"betas": ["interleaved-thinking-2025-05-14"],
|
||||
}
|
||||
if self._settings["thinking"]:
|
||||
params["thinking"] = self._settings["thinking"].model_dump(exclude_unset=True)
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
# LLM completion
|
||||
response = await self._client.messages.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
system=system,
|
||||
max_tokens=8192,
|
||||
stream=False,
|
||||
)
|
||||
response = await self._client.beta.messages.create(**params)
|
||||
|
||||
return response.content[0].text
|
||||
return next((block.text for block in response.content if hasattr(block, "text")), None)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
@@ -354,12 +398,21 @@ class AnthropicLLMService(LLMService):
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
# Add thinking parameter if set
|
||||
if self._settings["thinking"]:
|
||||
params["thinking"] = self._settings["thinking"].model_dump(exclude_unset=True)
|
||||
|
||||
# Messages, system, tools
|
||||
params.update(params_from_context)
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await self._create_message_stream(self._client.messages.create, params)
|
||||
# "Interleaved thinking" needed to allow thinking between sequences
|
||||
# of function calls, when extended thinking is enabled.
|
||||
# Note that this requires us to use `client.beta`, below.
|
||||
params.update({"betas": ["interleaved-thinking-2025-05-14"]})
|
||||
|
||||
response = await self._create_message_stream(self._client.beta.messages.create, params)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -380,10 +433,21 @@ class AnthropicLLMService(LLMService):
|
||||
completion_tokens_estimate += self._estimate_tokens(
|
||||
event.delta.partial_json
|
||||
)
|
||||
elif hasattr(event.delta, "thinking"):
|
||||
await self.push_frame(LLMThoughtTextFrame(text=event.delta.thinking))
|
||||
elif hasattr(event.delta, "signature"):
|
||||
await self.push_frame(LLMThoughtEndFrame(signature=event.delta.signature))
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
tool_use_block = event.content_block
|
||||
json_accumulator = ""
|
||||
elif event.content_block.type == "thinking":
|
||||
await self.push_frame(
|
||||
LLMThoughtStartFrame(
|
||||
append_to_context=True,
|
||||
llm=self.get_llm_adapter().id_for_llm_specific_messages,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
event.type == "message_delta"
|
||||
and hasattr(event.delta, "stop_reason")
|
||||
|
||||
@@ -17,11 +17,10 @@ from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat import __version__ as pipecat_version
|
||||
from pipecat import version as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -30,7 +29,7 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
@@ -44,15 +43,15 @@ from .models import (
|
||||
)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AssemblyAISTTService(STTService):
|
||||
class AssemblyAISTTService(WebsocketSTTService):
|
||||
"""AssemblyAI real-time speech-to-text service.
|
||||
|
||||
Provides real-time speech transcription using AssemblyAI's WebSocket API.
|
||||
@@ -80,15 +79,14 @@ class AssemblyAISTTService(STTService):
|
||||
vad_force_turn_endpoint: Whether to force turn endpoint on VAD stop. Defaults to True.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=connection_params.sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._language = language
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._connection_params = connection_params
|
||||
self._vad_force_turn_endpoint = vad_force_turn_endpoint
|
||||
|
||||
super().__init__(sample_rate=self._connection_params.sample_rate, **kwargs)
|
||||
|
||||
self._websocket = None
|
||||
self._termination_event = asyncio.Event()
|
||||
self._received_termination = False
|
||||
self._connected = False
|
||||
@@ -114,7 +112,7 @@ class AssemblyAISTTService(STTService):
|
||||
frame: Start frame to begin processing.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self._sample_rate * 2 / 1000)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -146,10 +144,11 @@ class AssemblyAISTTService(STTService):
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
await self._websocket.send(chunk)
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
await self._websocket.send(chunk)
|
||||
|
||||
yield None
|
||||
|
||||
@@ -164,7 +163,11 @@ class AssemblyAISTTService(STTService):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_ttfb_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
if self._vad_force_turn_endpoint:
|
||||
if (
|
||||
self._vad_force_turn_endpoint
|
||||
and self._websocket
|
||||
and self._websocket.state is State.OPEN
|
||||
):
|
||||
await self._websocket.send(json.dumps({"type": "ForceEndpoint"}))
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -191,27 +194,20 @@ class AssemblyAISTTService(STTService):
|
||||
return self._api_endpoint_base_url
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
ws_url = self._build_ws_url()
|
||||
headers = {
|
||||
"Authorization": self._api_key,
|
||||
"User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})",
|
||||
}
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
additional_headers=headers,
|
||||
)
|
||||
self._connected = True
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
"""Connect to the AssemblyAI service.
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
Establishes websocket connection and starts receive task.
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from AssemblyAI WebSocket and wait for termination message."""
|
||||
"""Disconnect from the AssemblyAI service.
|
||||
|
||||
Sends termination message, waits for acknowledgment, and cleans up.
|
||||
"""
|
||||
if not self._connected or not self._websocket:
|
||||
return
|
||||
|
||||
@@ -219,51 +215,96 @@ class AssemblyAISTTService(STTService):
|
||||
self._termination_event.clear()
|
||||
self._received_termination = False
|
||||
|
||||
if len(self._audio_buffer) > 0:
|
||||
await self._websocket.send(bytes(self._audio_buffer))
|
||||
self._audio_buffer.clear()
|
||||
|
||||
try:
|
||||
await self._websocket.send(json.dumps({"type": "Terminate"}))
|
||||
if self._websocket.state is State.OPEN:
|
||||
# Send any remaining audio
|
||||
if len(self._audio_buffer) > 0:
|
||||
await self._websocket.send(bytes(self._audio_buffer))
|
||||
self._audio_buffer.clear()
|
||||
|
||||
# Send termination message and wait for acknowledgment
|
||||
try:
|
||||
await asyncio.wait_for(self._termination_event.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
await self._websocket.send(json.dumps({"type": "Terminate"}))
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
try:
|
||||
await asyncio.wait_for(self._termination_event.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
# Clean up tasks and connection
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish the websocket connection to AssemblyAI."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to AssemblyAI WebSocket")
|
||||
|
||||
ws_url = self._build_ws_url()
|
||||
headers = {
|
||||
"Authorization": self._api_key,
|
||||
"User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version()})",
|
||||
}
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
additional_headers=headers,
|
||||
)
|
||||
self._connected = True
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug(f"{self} Connected to AssemblyAI WebSocket")
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
await self.push_error(error_msg=f"Unable to connect to AssemblyAI: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to AssemblyAI."""
|
||||
try:
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from AssemblyAI WebSocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._connected = False
|
||||
self._receive_task = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
try:
|
||||
while self._connected:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
break
|
||||
def _get_websocket(self):
|
||||
"""Get the current WebSocket connection.
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
Returns:
|
||||
The WebSocket connection.
|
||||
|
||||
Raises:
|
||||
Exception: If WebSocket is not connected.
|
||||
"""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process websocket messages.
|
||||
|
||||
Continuously processes messages from the websocket connection.
|
||||
"""
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
|
||||
@@ -56,6 +56,17 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
Language.ES: "es",
|
||||
Language.DE: "de",
|
||||
Language.IT: "it",
|
||||
Language.PT: "pt",
|
||||
Language.NL: "nl",
|
||||
Language.AR: "ar",
|
||||
Language.RU: "ru",
|
||||
Language.RO: "ro",
|
||||
Language.JA: "ja",
|
||||
Language.HE: "he",
|
||||
Language.HY: "hy",
|
||||
Language.TR: "tr",
|
||||
Language.HI: "hi",
|
||||
Language.ZH: "zh",
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
@@ -74,7 +85,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
language: Optional[Language] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -83,7 +94,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
voice_id: str,
|
||||
version: str = "v1",
|
||||
url: str = "wss://api.async.ai/text_to_speech/websocket/ws",
|
||||
model: str = "asyncflow_v2.0",
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
@@ -99,7 +110,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
https://docs.async.ai/list-voices-16699698e0
|
||||
version: Async API version.
|
||||
url: WebSocket URL for Async TTS API.
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
sample_rate: Audio sample rate.
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
@@ -128,7 +139,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en",
|
||||
else None,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
@@ -357,7 +368,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
language: Optional[Language] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -365,7 +376,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "asyncflow_v2.0",
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
url: str = "https://api.async.ai",
|
||||
version: str = "v1",
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -380,7 +391,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: Async API key.
|
||||
voice_id: ID of the voice to use for synthesis.
|
||||
aiohttp_session: An aiohttp session for making HTTP requests.
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
url: Base URL for Async API.
|
||||
version: API version string for Async API.
|
||||
sample_rate: Audio sample rate.
|
||||
@@ -404,7 +415,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en",
|
||||
else None,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
|
||||
@@ -840,15 +840,13 @@ class AWSBedrockLLMService(LLMService):
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
|
||||
# Prepare request parameters
|
||||
# Prepare request parameters using the same method as streaming
|
||||
inference_config = self._build_inference_config()
|
||||
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"modelId": self.model_name,
|
||||
"messages": messages,
|
||||
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
|
||||
}
|
||||
|
||||
if inference_config:
|
||||
|
||||
@@ -157,6 +157,12 @@ class Params(BaseModel):
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
top_p: Nucleus sampling parameter.
|
||||
temperature: Sampling temperature for text generation.
|
||||
endpointing_sensitivity: Controls how quickly Nova Sonic decides the
|
||||
user has stopped speaking. Can be "LOW", "MEDIUM", or "HIGH", with
|
||||
"HIGH" being the most sensitive (i.e., causing the model to respond
|
||||
most quickly).
|
||||
If not set, uses the model's default behavior.
|
||||
Only supported with Nova 2 Sonic (the default model).
|
||||
"""
|
||||
|
||||
# Audio input
|
||||
@@ -174,6 +180,9 @@ class Params(BaseModel):
|
||||
top_p: Optional[float] = Field(default=0.9)
|
||||
temperature: Optional[float] = Field(default=0.7)
|
||||
|
||||
# Turn-taking
|
||||
endpointing_sensitivity: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMService(LLMService):
|
||||
"""AWS Nova Sonic speech-to-speech LLM service.
|
||||
@@ -192,8 +201,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
access_key_id: str,
|
||||
session_token: Optional[str] = None,
|
||||
region: str,
|
||||
model: str = "amazon.nova-sonic-v1:0",
|
||||
voice_id: str = "matthew", # matthew, tiffany, amy
|
||||
model: str = "amazon.nova-2-sonic-v1:0",
|
||||
voice_id: str = "matthew",
|
||||
params: Optional[Params] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[ToolsSchema] = None,
|
||||
@@ -207,8 +216,15 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
access_key_id: AWS access key ID for authentication.
|
||||
session_token: AWS session token for authentication.
|
||||
region: AWS region where the service is hosted.
|
||||
model: Model identifier. Defaults to "amazon.nova-sonic-v1:0".
|
||||
voice_id: Voice ID for speech synthesis. Options: matthew, tiffany, amy.
|
||||
Supported regions:
|
||||
- Nova 2 Sonic (the default model): "us-east-1", "us-west-2", "ap-northeast-1"
|
||||
- Nova Sonic (the older model): "us-east-1", "ap-northeast-1"
|
||||
model: Model identifier. Defaults to "amazon.nova-2-sonic-v1:0".
|
||||
voice_id: Voice ID for speech synthesis.
|
||||
Note that some voices are designed for use with a specific language.
|
||||
Options:
|
||||
- Nova 2 Sonic (the default model): see https://docs.aws.amazon.com/nova/latest/nova2-userguide/sonic-language-support.html
|
||||
- Nova Sonic (the older model): see https://docs.aws.amazon.com/nova/latest/userguide/available-voices.html.
|
||||
params: Model parameters for audio configuration and inference.
|
||||
system_instruction: System-level instruction for the model.
|
||||
tools: Available tools/functions for the model to use.
|
||||
@@ -232,6 +248,17 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._system_instruction = system_instruction
|
||||
self._tools = tools
|
||||
|
||||
# Validate endpointing_sensitivity parameter
|
||||
if (
|
||||
self._params.endpointing_sensitivity
|
||||
and not self._is_endpointing_sensitivity_supported()
|
||||
):
|
||||
logger.warning(
|
||||
f"endpointing_sensitivity is not supported for model '{model}' and will be ignored. "
|
||||
"This parameter is only supported starting with Nova 2 Sonic (amazon.nova-2-sonic-v1:0)."
|
||||
)
|
||||
self._params.endpointing_sensitivity = None
|
||||
|
||||
if not send_transcription_frames:
|
||||
import warnings
|
||||
|
||||
@@ -459,7 +486,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
# Check for set of completed function calls in the context
|
||||
for message in self._context.get_messages():
|
||||
if message.get("role") and message.get("content") != "IN_PROGRESS":
|
||||
if message.get("role") and message.get("content") not in ["IN_PROGRESS", "CANCELLED"]:
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if tool_call_id and tool_call_id not in self._completed_tool_calls:
|
||||
# Found a newly-completed function call - send the result to the service
|
||||
@@ -591,11 +618,33 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
)
|
||||
return BedrockRuntimeClient(config=config)
|
||||
|
||||
def _is_first_generation_sonic_model(self) -> bool:
|
||||
# Nova Sonic (the older model) is identified by "amazon.nova-sonic-v1:0"
|
||||
return self._model == "amazon.nova-sonic-v1:0"
|
||||
|
||||
def _is_endpointing_sensitivity_supported(self) -> bool:
|
||||
# endpointing_sensitivity is only supported with Nova 2 Sonic (and,
|
||||
# presumably, future models)
|
||||
return not self._is_first_generation_sonic_model()
|
||||
|
||||
def _is_assistant_response_trigger_needed(self) -> bool:
|
||||
# Assistant response trigger audio is only needed with the older model
|
||||
return self._is_first_generation_sonic_model()
|
||||
|
||||
#
|
||||
# LLM communication: input events (pipecat -> LLM)
|
||||
#
|
||||
|
||||
async def _send_session_start_event(self):
|
||||
turn_detection_config = (
|
||||
f""",
|
||||
"turnDetectionConfiguration": {{
|
||||
"endpointingSensitivity": "{self._params.endpointing_sensitivity}"
|
||||
}}"""
|
||||
if self._params.endpointing_sensitivity
|
||||
else ""
|
||||
)
|
||||
|
||||
session_start = f"""
|
||||
{{
|
||||
"event": {{
|
||||
@@ -604,7 +653,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"maxTokens": {self._params.max_tokens},
|
||||
"topP": {self._params.top_p},
|
||||
"temperature": {self._params.temperature}
|
||||
}}
|
||||
}}{turn_detection_config}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
@@ -1189,7 +1238,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
)
|
||||
|
||||
#
|
||||
# assistant response trigger (HACK)
|
||||
# assistant response trigger
|
||||
# HACK: only needed for the older Nova Sonic (as opposed to Nova 2 Sonic) model
|
||||
#
|
||||
|
||||
# Class variable
|
||||
@@ -1203,12 +1253,17 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
Sends a pre-recorded "ready" audio trigger to prompt the assistant
|
||||
to start speaking. This is useful for controlling conversation flow.
|
||||
|
||||
Returns:
|
||||
False if already triggering a response, True otherwise.
|
||||
"""
|
||||
if not self._is_assistant_response_trigger_needed():
|
||||
logger.warning(
|
||||
f"Assistant response trigger not needed for model '{self._model}'; skipping. "
|
||||
"An LLMRunFrame() should be sufficient to prompt the assistant to respond, "
|
||||
"assuming the context ends in a user message."
|
||||
)
|
||||
return
|
||||
|
||||
if self._triggering_assistant_response:
|
||||
return False
|
||||
return
|
||||
|
||||
self._triggering_assistant_response = True
|
||||
|
||||
|
||||
@@ -29,13 +29,12 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
@@ -44,7 +43,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AWSTranscribeSTTService(STTService):
|
||||
class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
"""AWS Transcribe Speech-to-Text service using WebSocket streaming.
|
||||
|
||||
Provides real-time speech transcription using AWS Transcribe's streaming API.
|
||||
@@ -99,9 +98,6 @@ class AWSTranscribeSTTService(STTService):
|
||||
"region": region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
self._ws_client = None
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self._connecting = False
|
||||
self._receive_task = None
|
||||
|
||||
def get_service_encoding(self, encoding: str) -> str:
|
||||
@@ -123,29 +119,9 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
Args:
|
||||
frame: Start frame signaling service initialization.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If WebSocket connection cannot be established after retries.
|
||||
"""
|
||||
await super().start(frame)
|
||||
logger.info("Starting AWS Transcribe service...")
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._connect()
|
||||
if self._ws_client and self._ws_client.state is State.OPEN:
|
||||
logger.info("Successfully established WebSocket connection")
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
|
||||
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and disconnect from AWS Transcribe.
|
||||
@@ -174,140 +150,127 @@ class AWSTranscribeSTTService(STTService):
|
||||
Yields:
|
||||
ErrorFrame: If processing fails or connection issues occur.
|
||||
"""
|
||||
try:
|
||||
# Ensure WebSocket is connected
|
||||
if not self._ws_client or self._ws_client.state is State.CLOSED:
|
||||
logger.debug("WebSocket not connected, attempting to reconnect...")
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
event_message = build_event_message(audio)
|
||||
|
||||
# Send the formatted event message
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
try:
|
||||
await self._ws_client.send(event_message)
|
||||
# Format the audio data according to AWS event stream format
|
||||
event_message = build_event_message(audio)
|
||||
|
||||
# Send the formatted event message
|
||||
await self._websocket.send(event_message)
|
||||
# Start metrics after first chunk sent
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(f"Connection closed while sending: {e}")
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
await self._disconnect()
|
||||
yield ErrorFrame(error=f"Error sending audio: {e}")
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
await self._disconnect()
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to AWS Transcribe with connection state management."""
|
||||
if self._ws_client and self._ws_client.state is State.OPEN and self._receive_task:
|
||||
logger.debug(f"{self} Already connected")
|
||||
return
|
||||
"""Connect to the AWS Transcribe service.
|
||||
|
||||
async with self._connection_lock:
|
||||
if self._connecting:
|
||||
logger.debug(f"{self} Connection already in progress")
|
||||
return
|
||||
Establishes websocket connection and starts receive task.
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
try:
|
||||
self._connecting = True
|
||||
logger.debug(f"{self} Starting connection process...")
|
||||
|
||||
if self._ws_client:
|
||||
await self._disconnect()
|
||||
|
||||
language_code = self.language_to_service_language(
|
||||
Language(self._settings["language"])
|
||||
)
|
||||
if not language_code:
|
||||
raise ValueError(f"Unsupported language: {self._settings['language']}")
|
||||
|
||||
# Generate random websocket key
|
||||
websocket_key = "".join(
|
||||
random.choices(
|
||||
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
|
||||
)
|
||||
)
|
||||
|
||||
# Add required headers
|
||||
additional_headers = {
|
||||
"Origin": "https://localhost",
|
||||
"Sec-WebSocket-Key": websocket_key,
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
# Get presigned URL
|
||||
presigned_url = get_presigned_url(
|
||||
region=self._credentials["region"],
|
||||
credentials={
|
||||
"access_key": self._credentials["aws_access_key_id"],
|
||||
"secret_key": self._credentials["aws_secret_access_key"],
|
||||
"session_token": self._credentials["aws_session_token"],
|
||||
},
|
||||
language_code=language_code,
|
||||
media_encoding=self.get_service_encoding(
|
||||
self._settings["media_encoding"]
|
||||
), # Convert to AWS format
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
number_of_channels=self._settings["number_of_channels"],
|
||||
enable_partial_results_stabilization=True,
|
||||
partial_results_stability="high",
|
||||
show_speaker_label=self._settings["show_speaker_label"],
|
||||
enable_channel_identification=self._settings["enable_channel_identification"],
|
||||
)
|
||||
|
||||
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
|
||||
# Connect with the required headers and settings
|
||||
self._ws_client = await websocket_connect(
|
||||
presigned_url,
|
||||
additional_headers=additional_headers,
|
||||
subprotocols=["mqtt"],
|
||||
ping_interval=None,
|
||||
ping_timeout=None,
|
||||
compression=None,
|
||||
)
|
||||
|
||||
logger.debug(f"{self} WebSocket connected, starting receive task...")
|
||||
|
||||
# Start receive task
|
||||
self._receive_task = self.create_task(self._receive_loop())
|
||||
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._connecting = False
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from AWS Transcribe."""
|
||||
"""Disconnect from the AWS Transcribe service.
|
||||
|
||||
Sends end-stream message and cleans up.
|
||||
"""
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
try:
|
||||
if self._ws_client and self._ws_client.state is State.OPEN:
|
||||
# Send end-stream message
|
||||
# Send end-stream message before closing
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
try:
|
||||
end_stream = {"message-type": "event", "event": "end"}
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
await self._websocket.send(json.dumps(end_stream))
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error sending end-stream: {e}", exception=e)
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish the websocket connection to AWS Transcribe."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to AWS Transcribe WebSocket")
|
||||
|
||||
language_code = self.language_to_service_language(Language(self._settings["language"]))
|
||||
if not language_code:
|
||||
raise ValueError(f"Unsupported language: {self._settings['language']}")
|
||||
|
||||
# Generate random websocket key
|
||||
websocket_key = "".join(
|
||||
random.choices(
|
||||
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
|
||||
)
|
||||
)
|
||||
|
||||
# Add required headers
|
||||
additional_headers = {
|
||||
"Origin": "https://localhost",
|
||||
"Sec-WebSocket-Key": websocket_key,
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
# Get presigned URL
|
||||
presigned_url = get_presigned_url(
|
||||
region=self._credentials["region"],
|
||||
credentials={
|
||||
"access_key": self._credentials["aws_access_key_id"],
|
||||
"secret_key": self._credentials["aws_secret_access_key"],
|
||||
"session_token": self._credentials["aws_session_token"],
|
||||
},
|
||||
language_code=language_code,
|
||||
media_encoding=self.get_service_encoding(
|
||||
self._settings["media_encoding"]
|
||||
), # Convert to AWS format
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
number_of_channels=self._settings["number_of_channels"],
|
||||
enable_partial_results_stabilization=True,
|
||||
partial_results_stability="high",
|
||||
show_speaker_label=self._settings["show_speaker_label"],
|
||||
enable_channel_identification=self._settings["enable_channel_identification"],
|
||||
)
|
||||
|
||||
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
|
||||
# Connect with the required headers and settings
|
||||
self._websocket = await websocket_connect(
|
||||
presigned_url,
|
||||
additional_headers=additional_headers,
|
||||
subprotocols=["mqtt"],
|
||||
ping_interval=None,
|
||||
ping_timeout=None,
|
||||
compression=None,
|
||||
)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to AWS Transcribe: {e}", exception=e
|
||||
)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to AWS Transcribe."""
|
||||
try:
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from AWS Transcribe WebSocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._ws_client = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
@@ -471,16 +434,26 @@ class AWSTranscribeSTTService(STTService):
|
||||
):
|
||||
pass
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Background task to receive and process messages from AWS Transcribe."""
|
||||
while True:
|
||||
if not self._ws_client or self._ws_client.state is State.CLOSED:
|
||||
logger.warning(f"{self} WebSocket closed in receive loop")
|
||||
break
|
||||
def _get_websocket(self):
|
||||
"""Get the current WebSocket connection.
|
||||
|
||||
Returns:
|
||||
The WebSocket connection.
|
||||
|
||||
Raises:
|
||||
Exception: If WebSocket is not connected.
|
||||
"""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process websocket messages.
|
||||
|
||||
Continuously processes messages from the websocket connection.
|
||||
"""
|
||||
async for response in self._get_websocket():
|
||||
try:
|
||||
response = await self._ws_client.recv()
|
||||
|
||||
headers, payload = decode_event(response)
|
||||
|
||||
if headers.get(":message-type") == "event":
|
||||
@@ -527,11 +500,5 @@ class AWSTranscribeSTTService(STTService):
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
await self.push_error(
|
||||
error_msg=f"WebSocket connection closed in receive loop", exception=e
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
break
|
||||
logger.warning(f"Error processing message: {e}")
|
||||
|
||||
@@ -10,7 +10,6 @@ This module provides a WebSocket-based STT service that integrates with
|
||||
the Cartesia Live transcription API for real-time speech recognition.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import urllib.parse
|
||||
from typing import AsyncGenerator, Optional
|
||||
@@ -20,7 +19,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -160,20 +158,16 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
merged_options = default_options
|
||||
merged_options = default_options.to_dict()
|
||||
if live_options:
|
||||
merged_options_dict = default_options.to_dict()
|
||||
merged_options_dict.update(live_options.to_dict())
|
||||
merged_options = CartesiaLiveOptions(
|
||||
**{
|
||||
k: v
|
||||
for k, v in merged_options_dict.items()
|
||||
if not isinstance(v, str) or v != "None"
|
||||
}
|
||||
)
|
||||
merged_options.update(live_options.to_dict())
|
||||
# Filter out "None" string values
|
||||
merged_options = {
|
||||
k: v for k, v in merged_options.items() if not isinstance(v, str) or v != "None"
|
||||
}
|
||||
|
||||
self._settings = merged_options
|
||||
self.set_model_name(merged_options.model)
|
||||
self.set_model_name(merged_options["model"])
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url or "api.cartesia.ai"
|
||||
self._receive_task = None
|
||||
@@ -254,7 +248,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = asyncio.create_task(self._receive_task_handler(self._report_error))
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
@@ -269,7 +263,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
return
|
||||
logger.debug("Connecting to Cartesia STT")
|
||||
|
||||
params = self._settings.to_dict()
|
||||
params = self._settings
|
||||
ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}"
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
|
||||
@@ -295,12 +289,15 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _process_messages(self):
|
||||
"""Process incoming WebSocket messages."""
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
@@ -349,6 +346,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
@@ -361,5 +359,6 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -554,7 +554,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self.add_word_timestamps(processed_timestamps)
|
||||
elif msg["type"] == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self.sample_rate,
|
||||
|
||||
@@ -160,7 +160,7 @@ def build_elevenlabs_voice_settings(
|
||||
class PronunciationDictionaryLocator(BaseModel):
|
||||
"""Locator for a pronunciation dictionary.
|
||||
|
||||
Attributes:
|
||||
Parameters:
|
||||
pronunciation_dictionary_id: The ID of the pronunciation dictionary.
|
||||
version_id: The version ID of the pronunciation dictionary.
|
||||
"""
|
||||
@@ -617,7 +617,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
await self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
@@ -731,10 +731,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
else:
|
||||
await self._send_text(text)
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -870,6 +868,11 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
def _set_voice_settings(self):
|
||||
return build_elevenlabs_voice_settings(self._settings)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
await super()._update_settings(settings)
|
||||
# Update voice settings for the next context creation
|
||||
self._voice_settings = self._set_voice_settings()
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset internal state variables."""
|
||||
self._cumulative_time = 0
|
||||
@@ -1044,7 +1047,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
self.start_word_timestamps()
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
|
||||
@@ -19,11 +19,10 @@ from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat import __version__ as pipecat_version
|
||||
from pipecat import version as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -31,7 +30,7 @@ from pipecat.frames.frames import (
|
||||
TranslationFrame,
|
||||
)
|
||||
from pipecat.services.gladia.config import GladiaInputParams
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
@@ -176,7 +175,7 @@ class _InputParamsDescriptor:
|
||||
return GladiaInputParams
|
||||
|
||||
|
||||
class GladiaSTTService(STTService):
|
||||
class GladiaSTTService(WebsocketSTTService):
|
||||
"""Speech-to-Text service using Gladia's API.
|
||||
|
||||
This service connects to Gladia's WebSocket API for real-time transcription
|
||||
@@ -202,8 +201,6 @@ class GladiaSTTService(STTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
max_reconnection_attempts: int = 5,
|
||||
reconnection_delay: float = 1.0,
|
||||
max_buffer_size: int = 1024 * 1024 * 20, # 20MB default buffer
|
||||
**kwargs,
|
||||
):
|
||||
@@ -222,8 +219,6 @@ class GladiaSTTService(STTService):
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
model: Model to use for transcription. Defaults to "solaria-1".
|
||||
params: Additional configuration parameters for Gladia service.
|
||||
max_reconnection_attempts: Maximum number of reconnection attempts. Defaults to 5.
|
||||
reconnection_delay: Initial delay between reconnection attempts in seconds.
|
||||
max_buffer_size: Maximum size of audio buffer in bytes. Defaults to 20MB.
|
||||
**kwargs: Additional arguments passed to the STTService parent class.
|
||||
"""
|
||||
@@ -256,16 +251,13 @@ class GladiaSTTService(STTService):
|
||||
self._url = url
|
||||
self.set_model_name(model)
|
||||
self._params = params
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._settings = {}
|
||||
|
||||
# Reconnection settings
|
||||
self._max_reconnection_attempts = max_reconnection_attempts
|
||||
self._reconnection_delay = reconnection_delay
|
||||
self._reconnection_attempts = 0
|
||||
# Session management
|
||||
self._session_url = None
|
||||
self._session_id = None
|
||||
self._connection_active = False
|
||||
|
||||
# Audio buffer management
|
||||
@@ -274,9 +266,8 @@ class GladiaSTTService(STTService):
|
||||
self._max_buffer_size = max_buffer_size
|
||||
self._buffer_lock = asyncio.Lock()
|
||||
|
||||
# Connection management
|
||||
self._connection_task = None
|
||||
self._should_reconnect = True
|
||||
def __str__(self):
|
||||
return f"{self.name} [{self._session_id}]"
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate performance metrics.
|
||||
@@ -308,7 +299,7 @@ class GladiaSTTService(STTService):
|
||||
|
||||
# Add custom_metadata if provided
|
||||
settings["custom_metadata"] = dict(self._params.custom_metadata or {})
|
||||
settings["custom_metadata"]["pipecat"] = pipecat_version
|
||||
settings["custom_metadata"]["pipecat"] = pipecat_version()
|
||||
|
||||
# Add endpointing parameters if provided
|
||||
if self._params.endpointing is not None:
|
||||
@@ -355,11 +346,7 @@ class GladiaSTTService(STTService):
|
||||
frame: The start frame triggering service startup.
|
||||
"""
|
||||
await super().start(frame)
|
||||
if self._connection_task:
|
||||
return
|
||||
|
||||
self._should_reconnect = True
|
||||
self._connection_task = self.create_task(self._connection_handler())
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Gladia STT websocket connection.
|
||||
@@ -368,14 +355,8 @@ class GladiaSTTService(STTService):
|
||||
frame: The end frame triggering service shutdown.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
self._should_reconnect = False
|
||||
await self._send_stop_recording()
|
||||
|
||||
if self._connection_task:
|
||||
await self.cancel_task(self._connection_task)
|
||||
self._connection_task = None
|
||||
|
||||
await self._cleanup_connection()
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Gladia STT websocket connection.
|
||||
@@ -384,13 +365,7 @@ class GladiaSTTService(STTService):
|
||||
frame: The cancel frame triggering service cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
self._should_reconnect = False
|
||||
|
||||
if self._connection_task:
|
||||
await self.cancel_task(self._connection_task)
|
||||
self._connection_task = None
|
||||
|
||||
await self._cleanup_connection()
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Run speech-to-text on audio data.
|
||||
@@ -412,88 +387,93 @@ class GladiaSTTService(STTService):
|
||||
trim_size = len(self._audio_buffer) - self._max_buffer_size
|
||||
self._audio_buffer = self._audio_buffer[trim_size:]
|
||||
self._bytes_sent = max(0, self._bytes_sent - trim_size)
|
||||
logger.warning(f"Audio buffer exceeded max size, trimmed {trim_size} bytes")
|
||||
logger.warning(f"{self} Audio buffer exceeded max size, trimmed {trim_size} bytes")
|
||||
|
||||
# Send audio if connected
|
||||
if self._connection_active and self._websocket and self._websocket.state is State.OPEN:
|
||||
try:
|
||||
await self._send_audio(audio)
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(f"Websocket closed while sending audio chunk: {e}")
|
||||
logger.warning(f"{self} Websocket closed while sending audio chunk: {e}")
|
||||
self._connection_active = False
|
||||
|
||||
yield None
|
||||
|
||||
async def _connection_handler(self):
|
||||
"""Handle WebSocket connection with automatic reconnection."""
|
||||
while self._should_reconnect:
|
||||
try:
|
||||
# Initialize session if needed
|
||||
if not self._session_url:
|
||||
settings = self._prepare_settings()
|
||||
response = await self._setup_gladia(settings)
|
||||
self._session_url = response["url"]
|
||||
self._reconnection_attempts = 0
|
||||
logger.info(f"Session URL : {self._session_url}")
|
||||
async def _connect(self):
|
||||
"""Connect to the Gladia service.
|
||||
|
||||
# Connect with automatic reconnection
|
||||
async with websocket_connect(self._session_url) as websocket:
|
||||
try:
|
||||
self._websocket = websocket
|
||||
self._connection_active = True
|
||||
logger.debug(f"{self} Connected to Gladia WebSocket")
|
||||
Initializes the session if needed and establishes websocket connection.
|
||||
"""
|
||||
# Initialize session if needed
|
||||
if not self._session_url:
|
||||
settings = self._prepare_settings()
|
||||
response = await self._setup_gladia(settings)
|
||||
self._session_url = response["url"]
|
||||
self._session_id = response["id"]
|
||||
logger.info(f"{self} Session URL: {self._session_url}")
|
||||
|
||||
# Send buffered audio if any
|
||||
await self._send_buffered_audio()
|
||||
await self._connect_websocket()
|
||||
|
||||
# Start tasks
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
# Wait for tasks to complete
|
||||
await asyncio.gather(self._receive_task, self._keepalive_task)
|
||||
if self._websocket and not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(f"WebSocket connection closed: {e}")
|
||||
self._connection_active = False
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the Gladia service.
|
||||
|
||||
# Clean up tasks
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Attempt reconnect using helper
|
||||
if not await self._maybe_reconnect():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
break
|
||||
|
||||
# Reset session URL to get a new one
|
||||
self._session_url = None
|
||||
await asyncio.sleep(self._reconnection_delay)
|
||||
|
||||
async def _cleanup_connection(self):
|
||||
"""Clean up connection resources."""
|
||||
Cleans up tasks and closes websocket connection.
|
||||
"""
|
||||
self._connection_active = False
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish the websocket connection to Gladia."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug(f"{self}Connecting to Gladia WebSocket")
|
||||
|
||||
self._websocket = await websocket_connect(self._session_url)
|
||||
self._connection_active = True
|
||||
|
||||
# Reset byte tracking for new connection
|
||||
async with self._buffer_lock:
|
||||
self._bytes_sent = 0
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
# Send buffered audio if any
|
||||
await self._send_buffered_audio()
|
||||
|
||||
logger.debug(f"{self} Connected to Gladia WebSocket")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unable to connect to Gladia: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the websocket connection to Gladia."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
logger.debug(f"{self} Disconnecting from Gladia WebSocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _setup_gladia(self, settings: Dict[str, Any]):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
params = {}
|
||||
@@ -510,10 +490,10 @@ class GladiaSTTService(STTService):
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"Gladia error: {response.status}: {error_text or response.reason}"
|
||||
f"{self} Gladia error: {response.status}: {error_text or response.reason}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Failed to initialize Gladia session: {response.status} - {error_text}"
|
||||
f"{self} Failed to initialize Gladia session: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
@@ -541,28 +521,26 @@ class GladiaSTTService(STTService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(json.dumps({"type": "stop_recording"}))
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic empty audio chunks to keep the connection alive."""
|
||||
try:
|
||||
KEEPALIVE_SLEEP = 20
|
||||
while self._connection_active:
|
||||
# Send keepalive (Gladia times out after 30 seconds)
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
# Send an empty audio chunk as keepalive
|
||||
empty_audio = b""
|
||||
await self._send_audio(empty_audio)
|
||||
else:
|
||||
logger.debug("Websocket closed, stopping keepalive")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
def _get_websocket(self):
|
||||
"""Get the current WebSocket connection.
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
Returns:
|
||||
The WebSocket connection.
|
||||
|
||||
Raises:
|
||||
Exception: If WebSocket is not connected.
|
||||
"""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process websocket messages.
|
||||
|
||||
Continuously processes messages from the websocket connection.
|
||||
"""
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
content = json.loads(message)
|
||||
|
||||
# Handle audio chunk acknowledgments
|
||||
@@ -617,26 +595,24 @@ class GladiaSTTService(STTService):
|
||||
translation, "", time_now_iso8601(), translated_language
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"{self} Received non-JSON message: {message}")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic empty audio chunks to keep the connection alive."""
|
||||
try:
|
||||
KEEPALIVE_SLEEP = 20
|
||||
while self._connection_active:
|
||||
# Send keepalive (Gladia times out after 30 seconds)
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
# Send an empty audio chunk as keepalive
|
||||
empty_audio = b""
|
||||
await self._send_audio(empty_audio)
|
||||
else:
|
||||
logger.debug(f"{self} Websocket closed, stopping keepalive")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
logger.debug(f"{self} Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
if not self._should_reconnect:
|
||||
return False
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
await self.push_error(
|
||||
error_msg=f"Max reconnection attempts ({self._max_reconnection_attempts}) reached",
|
||||
)
|
||||
self._should_reconnect = False
|
||||
return False
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
logger.debug(
|
||||
f"{self} Reconnecting in {delay} seconds (attempt {self._reconnection_attempts}/{self._max_reconnection_attempts})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
return True
|
||||
|
||||
@@ -68,6 +68,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame, LLMSearchResult
|
||||
from pipecat.services.google.utils import update_google_client_http_options
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
@@ -562,18 +563,18 @@ class InputParams(BaseModel):
|
||||
context_window_compression: Context compression settings. Defaults to None.
|
||||
thinking: Thinking settings. Defaults to None.
|
||||
Note that these settings may require specifying a model that
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-09-2025".
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-12-2025".
|
||||
enable_affective_dialog: Enable affective dialog, which allows Gemini
|
||||
to adapt to expression and tone. Defaults to None.
|
||||
Note that these settings may require specifying a model that
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-09-2025".
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-12-2025".
|
||||
Also note that this setting may require specifying an API version that
|
||||
supports it, e.g. HttpOptions(api_version="v1alpha").
|
||||
proactivity: Proactivity settings, which allows Gemini to proactively
|
||||
decide how to behave, such as whether to avoid responding to
|
||||
content that is not relevant. Defaults to None.
|
||||
Note that these settings may require specifying a model that
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-09-2025".
|
||||
supports them, e.g. "gemini-2.5-flash-native-audio-preview-12-2025".
|
||||
Also note that this setting may require specifying an API version that
|
||||
supports it, e.g. HttpOptions(api_version="v1alpha").
|
||||
extra: Additional parameters. Defaults to empty dict.
|
||||
@@ -614,7 +615,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
model="models/gemini-2.0-flash-live-001",
|
||||
model="models/gemini-2.5-flash-native-audio-preview-12-2025",
|
||||
voice_id: str = "Charon",
|
||||
start_audio_paused: bool = False,
|
||||
start_video_paused: bool = False,
|
||||
@@ -637,7 +638,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
Please use `http_options` to customize requests made by the
|
||||
API client.
|
||||
|
||||
model: Model identifier to use. Defaults to "models/gemini-2.0-flash-live-001".
|
||||
model: Model identifier to use. Defaults to "models/gemini-2.5-flash-native-audio-preview-12-2025".
|
||||
voice_id: TTS voice identifier. Defaults to "Charon".
|
||||
start_audio_paused: Whether to start with audio input paused. Defaults to False.
|
||||
start_video_paused: Whether to start with video input paused. Defaults to False.
|
||||
@@ -681,7 +682,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._video_input_paused = start_video_paused
|
||||
self._context = None
|
||||
self._api_key = api_key
|
||||
self._http_options = http_options
|
||||
self._http_options = update_google_client_http_options(http_options)
|
||||
self._session: AsyncSession = None
|
||||
self._connection_task = None
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class GeminiLiveVertexLLMService(GeminiLiveLLMService):
|
||||
credentials_path: Optional[str] = None,
|
||||
location: str,
|
||||
project_id: str,
|
||||
model="google/gemini-2.0-flash-live-preview-04-09",
|
||||
model="google/gemini-live-2.5-flash-native-audio",
|
||||
voice_id: str = "Charon",
|
||||
start_audio_paused: bool = False,
|
||||
start_video_paused: bool = False,
|
||||
@@ -70,7 +70,7 @@ class GeminiLiveVertexLLMService(GeminiLiveLLMService):
|
||||
credentials_path: Path to the service account JSON file.
|
||||
location: GCP region for Vertex AI endpoint (e.g., "us-east4").
|
||||
project_id: Google Cloud project ID.
|
||||
model: Model identifier to use. Defaults to "models/gemini-2.0-flash-live-preview-04-09".
|
||||
model: Model identifier to use. Defaults to "models/gemini-live-2.5-flash-native-audio".
|
||||
voice_id: TTS voice identifier. Defaults to "Charon".
|
||||
start_audio_paused: Whether to start with audio input paused. Defaults to False.
|
||||
start_video_paused: Whether to start with video input paused. Defaults to False.
|
||||
@@ -126,6 +126,7 @@ class GeminiLiveVertexLLMService(GeminiLiveLLMService):
|
||||
credentials=self._credentials,
|
||||
project=self._project_id,
|
||||
location=self._location,
|
||||
http_options=self._http_options,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -16,13 +16,14 @@ import os
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
||||
from pipecat.services.google.utils import update_google_client_http_options
|
||||
from pipecat.services.image_service import ImageGenService
|
||||
|
||||
try:
|
||||
@@ -60,6 +61,7 @@ class GoogleImageGenService(ImageGenService):
|
||||
*,
|
||||
api_key: str,
|
||||
params: Optional[InputParams] = None,
|
||||
http_options: Optional[Any] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the GoogleImageGenService with API key and parameters.
|
||||
@@ -67,11 +69,16 @@ class GoogleImageGenService(ImageGenService):
|
||||
Args:
|
||||
api_key: Google AI API key for authentication.
|
||||
params: Configuration parameters for image generation. Defaults to InputParams().
|
||||
http_options: HTTP options for the client.
|
||||
**kwargs: Additional arguments passed to the parent ImageGenService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._params = params or GoogleImageGenService.InputParams()
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
|
||||
# Add client header
|
||||
http_options = update_google_client_http_options(http_options)
|
||||
|
||||
self._client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
self.set_model_name(self._params.model)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
from typing import Any, AsyncIterator, Dict, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter, GeminiLLMInvocationParams
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
@@ -32,8 +33,12 @@ from pipecat.frames.frames import (
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
OutputImageRawFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -50,6 +55,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.google.utils import update_google_client_http_options
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
@@ -473,11 +479,16 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
if c["type"] == "text":
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
# Extract MIME type from data URL (format: "data:image/jpeg;base64,...")
|
||||
url = c["image_url"]["url"]
|
||||
mime_type = (
|
||||
url.split(":")[1].split(";")[0] if url.startswith("data:") else "image/jpeg"
|
||||
)
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
mime_type=mime_type,
|
||||
data=base64.b64decode(url.split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -665,6 +676,34 @@ class GoogleLLMService(LLMService):
|
||||
# Overriding the default adapter to use the Gemini one.
|
||||
adapter_class = GeminiLLMAdapter
|
||||
|
||||
class ThinkingConfig(BaseModel):
|
||||
"""Configuration for controlling the model's internal "thinking" process used before generating a response.
|
||||
|
||||
Gemini 2.5 and 3 series models have this thinking process.
|
||||
|
||||
Parameters:
|
||||
thinking_level: Thinking level for Gemini 3 Pro. Can be "low" or "high".
|
||||
If not provided, Gemini 3 Pro defaults to "high".
|
||||
Note: Gemini 2.5 series should use thinking_budget instead.
|
||||
thinking_budget: Token budget for thinking, for Gemini 2.5 series.
|
||||
-1 for dynamic thinking (model decides), 0 to disable thinking,
|
||||
or a specific token count (e.g., 128-32768 for 2.5 Pro).
|
||||
If not provided, most models today default to dynamic thinking.
|
||||
See https://ai.google.dev/gemini-api/docs/thinking#set-budget
|
||||
for default values and allowed ranges.
|
||||
Note: Gemini 3 Pro should use thinking_level instead.
|
||||
include_thoughts: Whether to include thought summaries in the response.
|
||||
Today's models default to not including thoughts (False).
|
||||
"""
|
||||
|
||||
thinking_budget: Optional[int] = Field(default=None)
|
||||
|
||||
# Why `| str` here? To not break compatibility in case Google adds more
|
||||
# levels in the future.
|
||||
thinking_level: Optional[Literal["low", "high"] | str] = Field(default=None)
|
||||
|
||||
include_thoughts: Optional[bool] = Field(default=None)
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Google AI models.
|
||||
|
||||
@@ -673,6 +712,12 @@ class GoogleLLMService(LLMService):
|
||||
temperature: Sampling temperature between 0.0 and 2.0.
|
||||
top_k: Top-k sampling parameter.
|
||||
top_p: Top-p sampling parameter between 0.0 and 1.0.
|
||||
thinking: Thinking configuration with thinking_budget, thinking_level, and include_thoughts.
|
||||
Used to control the model's internal "thinking" process used before generating a response.
|
||||
Gemini 2.5 series models use thinking_budget; Gemini 3 models use thinking_level.
|
||||
If this is not provided, Pipecat disables thinking for all
|
||||
models where that's possible (the 2.5 series, except 2.5 Pro),
|
||||
to reduce latency.
|
||||
extra: Additional parameters as a dictionary.
|
||||
"""
|
||||
|
||||
@@ -680,6 +725,7 @@ class GoogleLLMService(LLMService):
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
thinking: Optional["GoogleLLMService.ThinkingConfig"] = Field(default=None)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
@@ -713,13 +759,14 @@ class GoogleLLMService(LLMService):
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
self._system_instruction = system_instruction
|
||||
self._http_options = http_options
|
||||
self._http_options = update_google_client_http_options(http_options)
|
||||
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"thinking": params.thinking,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self._tools = tools
|
||||
@@ -751,17 +798,25 @@ class GoogleLLMService(LLMService):
|
||||
"""
|
||||
messages = []
|
||||
system = []
|
||||
tools = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system_instruction"]
|
||||
tools = params["tools"]
|
||||
else:
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system_message", None)
|
||||
tools = context.tools or []
|
||||
|
||||
generation_config = GenerateContentConfig(system_instruction=system)
|
||||
# Build generation config using the same method as streaming
|
||||
generation_params = self._build_generation_params(
|
||||
system_instruction=system, tools=tools if tools else None
|
||||
)
|
||||
|
||||
generation_config = GenerateContentConfig(**generation_params)
|
||||
|
||||
# Use the new google-genai client's async method
|
||||
response = await self._client.aio.models.generate_content(
|
||||
@@ -778,6 +833,48 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
return None
|
||||
|
||||
def _build_generation_params(
|
||||
self,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[List] = None,
|
||||
tool_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build generation parameters for Google AI API.
|
||||
|
||||
Args:
|
||||
system_instruction: Optional system instruction to use.
|
||||
tools: Optional list of tools to include.
|
||||
tool_config: Optional tool configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary of generation parameters with None values filtered out.
|
||||
"""
|
||||
# Filter out None values and create GenerationContentConfig
|
||||
generation_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"system_instruction": system_instruction,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
"tools": tools,
|
||||
"tool_config": tool_config,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Add thinking parameters if configured
|
||||
if self._settings["thinking"]:
|
||||
generation_params["thinking_config"] = self._settings["thinking"].model_dump(
|
||||
exclude_unset=True
|
||||
)
|
||||
|
||||
if self._settings["extra"]:
|
||||
generation_params.update(self._settings["extra"])
|
||||
|
||||
return generation_params
|
||||
|
||||
def _maybe_unset_thinking_budget(self, generation_params: Dict[str, Any]):
|
||||
try:
|
||||
# There's no way to introspect on model capabilities, so
|
||||
@@ -815,30 +912,15 @@ class GoogleLLMService(LLMService):
|
||||
if self._tool_config:
|
||||
tool_config = self._tool_config
|
||||
|
||||
# Filter out None values and create GenerationContentConfig
|
||||
generation_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"system_instruction": self._system_instruction,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
"tools": tools,
|
||||
"tool_config": tool_config,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
if self._settings["extra"]:
|
||||
generation_params.update(self._settings["extra"])
|
||||
# Build generation parameters
|
||||
generation_params = self._build_generation_params(
|
||||
system_instruction=self._system_instruction, tools=tools, tool_config=tool_config
|
||||
)
|
||||
|
||||
# possibly modify generation_params (in place) to set thinking to off by default
|
||||
self._maybe_unset_thinking_budget(generation_params)
|
||||
|
||||
generation_config = (
|
||||
GenerateContentConfig(**generation_params) if generation_params else None
|
||||
)
|
||||
generation_config = GenerateContentConfig(**generation_params)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
return await self._client.aio.models.generate_content_stream(
|
||||
@@ -885,7 +967,7 @@ class GoogleLLMService(LLMService):
|
||||
reasoning_tokens = 0
|
||||
|
||||
grounding_metadata = None
|
||||
search_result = ""
|
||||
accumulated_text = ""
|
||||
|
||||
try:
|
||||
# Generate content using either OpenAILLMContext or universal LLMContext
|
||||
@@ -918,27 +1000,91 @@ class GoogleLLMService(LLMService):
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if not part.thought and part.text:
|
||||
search_result += part.text
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
function_call_id = None
|
||||
if part.text:
|
||||
if part.thought:
|
||||
# Gemini emits fully-formed thoughts rather
|
||||
# than chunks so bracket each thought in
|
||||
# start/end
|
||||
await self.push_frame(LLMThoughtStartFrame())
|
||||
await self.push_frame(LLMThoughtTextFrame(part.text))
|
||||
await self.push_frame(LLMThoughtEndFrame())
|
||||
else:
|
||||
accumulated_text += part.text
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
logger.debug(f"Function call: {function_call.name}:{id}")
|
||||
function_call_id = function_call.id or str(uuid.uuid4())
|
||||
logger.debug(
|
||||
f"Function call: {function_call.name}:{function_call_id}"
|
||||
)
|
||||
function_calls.append(
|
||||
FunctionCallFromLLM(
|
||||
context=context,
|
||||
tool_call_id=id,
|
||||
tool_call_id=function_call_id,
|
||||
function_name=function_call.name,
|
||||
arguments=function_call.args or {},
|
||||
)
|
||||
)
|
||||
elif part.inline_data and part.inline_data.data:
|
||||
# Here we assume that inline_data is an image.
|
||||
image = Image.open(io.BytesIO(part.inline_data.data))
|
||||
frame = OutputImageRawFrame(
|
||||
image=image.tobytes(), size=image.size, format="RGB"
|
||||
await self.push_frame(
|
||||
AssistantImageRawFrame(
|
||||
image=image.tobytes(),
|
||||
size=image.size,
|
||||
format="RGB",
|
||||
original_data=part.inline_data.data,
|
||||
original_mime_type=part.inline_data.mime_type,
|
||||
)
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Handle Gemini thought signatures.
|
||||
#
|
||||
# - Gemini 2.5: they appear on function_call Parts,
|
||||
# and then (surprisingly) on the last(*) Part of
|
||||
# model responses following the first function_call
|
||||
# in a conversation.
|
||||
# - Gemini 3 Pro: they appear on the last(*) Part
|
||||
# of model responses, regardless of Part type.
|
||||
#
|
||||
# (*) Since we're using the streaming API, though,
|
||||
# where text Parts may be split across multiple
|
||||
# chunks (each represented by a Part, confusingly),
|
||||
# signatures may actually appear with the first
|
||||
# chunk (Gemini 2.5) or in a trailing empty-text
|
||||
# chunk (Gemini 3 Pro).
|
||||
if part.thought_signature:
|
||||
# Save a "bookmark" for the signature, so we
|
||||
# can later be sure we've put it in the right
|
||||
# place in context when sending the context
|
||||
# back to the LLM to continue the conversation.
|
||||
bookmark = {}
|
||||
if part.function_call:
|
||||
bookmark["function_call"] = function_call_id
|
||||
elif part.inline_data and part.inline_data.data:
|
||||
bookmark["inline_data"] = part.inline_data
|
||||
elif part.text is not None:
|
||||
# Account for Gemini 3 Pro trailing
|
||||
# empty-text chunk by using all the text
|
||||
# seen so far in this response's chunks.
|
||||
bookmark["text"] = accumulated_text
|
||||
else:
|
||||
logger.warning("Thought signature found on unhandled Part type")
|
||||
if bookmark:
|
||||
await self.push_frame(
|
||||
LLMMessagesAppendFrame(
|
||||
[
|
||||
self.get_llm_adapter().create_llm_specific_message(
|
||||
{
|
||||
"type": "thought_signature",
|
||||
"signature": part.thought_signature,
|
||||
"bookmark": bookmark,
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
candidate.grounding_metadata
|
||||
@@ -987,7 +1133,7 @@ class GoogleLLMService(LLMService):
|
||||
finally:
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
search_result=search_result,
|
||||
search_result=accumulated_text,
|
||||
origins=grounding_metadata["origins"],
|
||||
rendered_content=grounding_metadata["rendered_content"],
|
||||
)
|
||||
@@ -1049,6 +1195,14 @@ class GoogleLLMService(LLMService):
|
||||
# Do nothing - we're shutting down anyway
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings):
    """Apply settings, coercing a raw ``thinking`` dict into a ThinkingConfig.

    Args:
        settings: Mapping of setting names to new values. When it holds a
            plain dict under ``"thinking"``, that dict is expanded into
            ``self.ThinkingConfig`` before being handed to the parent class.
    """
    raw_thinking = settings["thinking"] if "thinking" in settings else None
    if isinstance(raw_thinking, dict):
        # Work on a copy so the caller's mapping is left untouched.
        converted = dict(settings)
        converted["thinking"] = self.ThinkingConfig(**raw_thinking)
        settings = converted
    await super()._update_settings(settings)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
|
||||
43
src/pipecat/services/google/utils.py
Normal file
43
src/pipecat/services/google/utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Utility functions for Google services."""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pipecat import version as pipecat_version
|
||||
|
||||
|
||||
def update_google_client_http_options(http_options: Optional[Union[Dict[str, Any], Any]]) -> Any:
    """Return http_options with the pipecat ``x-goog-api-client`` header set.

    Args:
        http_options: The existing http_options, which can be None, a dictionary,
            or an object with a 'headers' attribute.

    Returns:
        The updated http_options. For ``None`` a new dict is created; for a
        dict a copy is returned and neither the input dict nor its ``headers``
        dict is modified; for any other object exposing ``headers`` the object
        itself is updated in place and returned. Objects without a ``headers``
        attribute are returned unchanged.
    """
    client_header = {"x-goog-api-client": f"pipecat/{pipecat_version()}"}

    if http_options is None:
        return {"headers": client_header}

    if isinstance(http_options, dict):
        # dict.copy() is shallow: calling .update() on the copied "headers"
        # value would still mutate the caller's headers dict. Build a fresh
        # headers dict instead; this also tolerates an explicit
        # ``"headers": None`` entry.
        merged_headers = dict(http_options.get("headers") or {})
        merged_headers.update(client_header)
        updated_options = http_options.copy()
        updated_options["headers"] = merged_headers
        return updated_options

    if hasattr(http_options, "headers"):
        # We can't easily copy an arbitrary object, so we modify it in place.
        # This assumes the object is mutable and it's safe to do so.
        if http_options.headers is None:
            http_options.headers = client_header
        else:
            http_options.headers.update(client_header)

    return http_options
|
||||
5
src/pipecat/services/gradium/__init__.py
Normal file
5
src/pipecat/services/gradium/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
239
src/pipecat/services/gradium/stt.py
Normal file
239
src/pipecat/services/gradium/stt.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gradium's speech-to-text service implementation.
|
||||
|
||||
This module provides integration with Gradium's real-time speech-to-text
|
||||
WebSocket API for streaming audio transcription.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
    """Gradium real-time speech-to-text service.

    Buffers incoming audio, streams it to Gradium's WebSocket API as
    fixed-duration, base64-encoded PCM chunks, and pushes a
    TranscriptionFrame for every "text" message received from the server.
    """

    def __init__(
        self,
        *,
        api_key: str,
        api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
        json_config: str | None = None,
        **kwargs,
    ):
        """Initialize the Gradium STT service.

        Args:
            api_key: Gradium API key for authentication.
            api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
            json_config: Optional JSON configuration string for additional model settings.
            **kwargs: Additional arguments passed to parent STTService class.
        """
        super().__init__(sample_rate=SAMPLE_RATE, **kwargs)

        self._api_key = api_key
        self._api_endpoint_base_url = api_endpoint_base_url
        self._websocket = None
        self._json_config = json_config

        self._receive_task = None

        # Audio is accumulated here and sent in chunks of _chunk_size_ms.
        self._audio_buffer = bytearray()
        self._chunk_size_ms = 80
        # Computed in start(), once the effective sample rate is known.
        self._chunk_size_bytes = 0

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            True if metrics generation is supported.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the speech-to-text service.

        Args:
            frame: Start frame to begin processing.
        """
        await super().start(frame)
        # NOTE(review): the factor 2 presumably means 2 bytes per sample
        # (16-bit mono PCM) -- confirm against the transport's audio format.
        self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the speech-to-text service.

        Args:
            frame: End frame to stop processing.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the speech-to-text service.

        Args:
            frame: Cancel frame to abort processing.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data for speech-to-text conversion.

        Args:
            audio: Raw audio bytes to process.

        Yields:
            None (processing handled via WebSocket messages).
        """
        self._audio_buffer.extend(audio)
        await self.start_ttfb_metrics()
        await self.start_processing_metrics()

        chunk_size = self._chunk_size_bytes
        # The chunk size stays 0 until start() has run. A zero-sized chunk
        # never shrinks the buffer, so without this guard the loop would spin
        # forever; keep buffering instead.
        while chunk_size > 0 and len(self._audio_buffer) >= chunk_size:
            chunk = bytes(self._audio_buffer[:chunk_size])
            # Drop the consumed bytes in place rather than copying the tail.
            del self._audio_buffer[:chunk_size]
            encoded = base64.b64encode(chunk).decode("utf-8")
            msg = {"type": "audio", "audio": encoded}
            # Chunks are dropped (not re-queued) while the socket is not open.
            if self._websocket and self._websocket.state is State.OPEN:
                await self._websocket.send(json.dumps(msg))

        yield None

    @traced_stt
    async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
        """Record transcription event for tracing."""
        pass

    async def _connect(self):
        """Open the websocket and start the receive task if it is not running."""
        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _connect_websocket(self):
        """Connect to Gradium and perform the setup/ready handshake.

        Raises:
            Exception: Re-raised after being reported through push_error when
                the connection or the handshake fails.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return
            headers = {
                "x-api-key": self._api_key,
                "x-api-source": "pipecat",
            }
            self._websocket = await websocket_connect(
                self._api_endpoint_base_url,
                additional_headers=headers,
            )
            setup_msg = {
                "type": "setup",
                "input_format": "pcm",
            }
            if self._json_config is not None:
                setup_msg["json_config"] = self._json_config
            await self._websocket.send(json.dumps(setup_msg))
            ready_msg = json.loads(await self._websocket.recv())
            if ready_msg["type"] == "error":
                raise Exception(f"received error {ready_msg.get('message', ready_msg)}")
            if ready_msg["type"] != "ready":
                raise Exception(f"unexpected first message type {ready_msg['type']}")
            # Only announce the connection once the server accepted the setup,
            # matching the Gradium TTS service.
            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # Do not keep (or leak) a socket whose handshake did not complete.
            failed_websocket, self._websocket = self._websocket, None
            if failed_websocket is not None:
                try:
                    await failed_websocket.close()
                except Exception as close_error:
                    logger.debug(f"{self} error closing failed websocket: {close_error}")
            raise

    async def _disconnect(self):
        """Cancel the receive task and close the websocket."""
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _disconnect_websocket(self):
        """Close the websocket if open, then reset state and notify handlers."""
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                logger.debug("Disconnecting from Gradium STT")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the current websocket.

        Raises:
            Exception: If no websocket is currently set.
        """
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _process_messages(self):
        """Consume websocket messages until the connection ends."""
        async for message in self._get_websocket():
            try:
                data = json.loads(message)
                await self._process_response(data)
            except json.JSONDecodeError:
                logger.warning(f"Received non-JSON message: {message}")

    async def _receive_messages(self):
        """Process messages, reconnecting whenever the message stream ends."""
        while True:
            await self._process_messages()
            logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
            await self._connect_websocket()

    async def _process_response(self, msg):
        """Dispatch one decoded server message by its "type" field."""
        type_ = msg.get("type", "")
        if type_ == "text":
            await self._handle_text(msg["text"])
        elif type_ == "end_of_stream":
            await self._handle_end_of_stream()
        elif type_ == "error":
            await self.push_error(error_msg=f"Error: {msg}")

    async def _handle_end_of_stream(self):
        """Handle termination message."""
        logger.debug("Received end_of_stream message from server")

    async def _handle_text(self, text: str):
        """Handle transcription results."""
        # run_stt() starts these metrics for every audio chunk; close them
        # when text arrives so they are actually reported.
        await self.stop_ttfb_metrics()
        await self.stop_processing_metrics()
        await self.push_frame(
            TranscriptionFrame(
                text,
                self._user_id,
                time_now_iso8601(),
            )
        )
|
||||
315
src/pipecat/services/gradium/tts.py
Normal file
315
src/pipecat/services/gradium/tts.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
|
||||
"""Gradium Text-to-Speech service implementation."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets import ConnectionClosedOK
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Gradium, you need to `pip install pipecat-ai[gradium]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class GradiumTTSService(InterruptibleWordTTSService):
    """Text-to-Speech service using Gradium's websocket API.

    Sends text over a persistent websocket and turns the server's "audio",
    "text", "end_of_stream" and "error" messages into audio frames, word
    timestamps, stop frames and errors respectively.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Gradium TTS service.

        Parameters:
            temp: Temperature to be used for generation, defaults to 0.6.
        """

        temp: Optional[float] = 0.6

    def __init__(
        self,
        *,
        api_key: str,
        voice_id: str = "YTpq7expH9539ERJ",
        url: str = "wss://eu.api.gradium.ai/api/speech/tts",
        model: str = "default",
        json_config: Optional[str] = None,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize the Gradium TTS service.

        Args:
            api_key: Gradium API key for authentication.
            voice_id: the voice identifier.
            url: Gradium websocket API endpoint.
            model: Model ID to use for synthesis.
            json_config: Optional JSON configuration string for additional model settings.
            params: Additional configuration parameters.
            **kwargs: Additional arguments passed to parent class.
        """
        # Initialize with parent class settings for proper frame handling
        super().__init__(
            push_stop_frames=True,
            pause_frame_processing=True,
            sample_rate=SAMPLE_RATE,
            **kwargs,
        )

        # NOTE(review): params (and therefore params.temp) is defaulted here
        # but never stored or sent to the API by this class -- confirm whether
        # it should be part of the setup message.
        params = params or GradiumTTSService.InputParams()

        # Store service configuration
        self._api_key = api_key
        self._url = url
        self._voice_id = voice_id
        self._json_config = json_config
        self._model = model
        self._settings = {
            "voice_id": voice_id,
            "model_name": model,
            "output_format": "pcm",
        }

        # State tracking
        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Gradium service supports metrics generation.
        """
        return True

    async def set_model(self, model: str):
        """Update the TTS model.

        Args:
            model: The model name to use for synthesis.
        """
        self._model = model
        await super().set_model(model)

    async def _update_settings(self, settings: Mapping[str, Any]):
        """Update service settings and reconnect if voice changed."""
        prev_voice = self._voice_id
        await super()._update_settings(settings)
        if not prev_voice == self._voice_id:
            self._settings["voice_id"] = self._voice_id
            logger.info(f"Switching TTS voice to: [{self._voice_id}]")
            # The voice is sent in the setup message, so a new connection is
            # needed for the change to take effect.
            await self._disconnect()
            await self._connect()

    def _build_msg(self, text: str = "") -> dict:
        """Build JSON message for Gradium API."""
        return {"text": text, "type": "text"}

    async def start(self, frame: StartFrame):
        """Start the service and establish websocket connection.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the service and close connection.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel current operation and clean up.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Establish websocket connection and start receive task."""
        logger.debug(f"{self}: connecting")

        # If the server disconnected, cancel the receive-task so that it can be reset below.
        if self._websocket is None or self._websocket.state is not State.OPEN:
            if self._receive_task:
                await self.cancel_task(self._receive_task)
                self._receive_task = None

        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            logger.debug(f"{self}: setting receive task")
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Close websocket connection and clean up tasks."""
        logger.debug(f"{self}: disconnecting")
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Connect to Gradium websocket API with configured settings.

        Failures are reported through push_error and the
        "on_connection_error" event instead of being raised.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
            self._websocket = await websocket_connect(self._url, additional_headers=headers)

            setup_msg = {
                "type": "setup",
                "output_format": "pcm",
                "voice_id": self._voice_id,
            }
            if self._json_config is not None:
                setup_msg["json_config"] = self._json_config
            await self._websocket.send(json.dumps(setup_msg))
            ready_msg = json.loads(await self._websocket.recv())
            if ready_msg["type"] == "error":
                raise Exception(f"received error {ready_msg.get('message', ready_msg)}")
            if ready_msg["type"] != "ready":
                raise Exception(f"unexpected first message type {ready_msg['type']}")

            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # Close a socket that was opened before the handshake failed
            # rather than just dropping the reference to it.
            failed_websocket, self._websocket = self._websocket, None
            if failed_websocket is not None:
                try:
                    await failed_websocket.close()
                except Exception as close_error:
                    logger.debug(f"{self}: error closing failed websocket: {close_error}")
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Close websocket connection and reset state."""
        try:
            await self.stop_all_metrics()
            if self._websocket:
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Get active websocket connection or raise exception."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def flush_audio(self):
        """Flush any pending audio synthesis."""
        if not self._websocket:
            return
        try:
            msg = {"type": "end_of_stream"}
            await self._websocket.send(json.dumps(msg))
        except ConnectionClosedOK:
            logger.debug(f"{self}: connection closed normally during flush")
        except Exception as e:
            logger.error(f"{self} exception: {e}")

    async def _receive_messages(self):
        """Process incoming websocket messages."""
        # TODO(laurent): This should not be necessary as it should happen when
        # receiving the messages but this does not seem to always be the case
        # and that may lead to a busy polling loop.
        if self._websocket and self._websocket.state is State.CLOSED:
            raise ConnectionClosedOK(None, None)
        async for message in self._get_websocket():
            msg = json.loads(message)
            # Use .get() so a message without "type" is ignored instead of
            # killing the receive loop with a KeyError.
            msg_type = msg.get("type")

            if msg_type == "audio":
                # Process audio chunk
                await self.stop_ttfb_metrics()
                await self.start_word_timestamps()
                frame = TTSAudioRawFrame(
                    audio=base64.b64decode(msg["audio"]),
                    sample_rate=self.sample_rate,
                    num_channels=1,
                )
                await self.push_frame(frame)

            elif msg_type == "text":
                await self.add_word_timestamps([(msg["text"], msg["start_s"])])
            elif msg_type == "end_of_stream":
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()

            elif msg_type == "error":
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()
                # Fall back to the whole payload if the server omitted "message".
                await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Push frame and handle end-of-turn conditions.

        Args:
            frame: The frame to push.
            direction: The direction to push the frame.
        """
        await super().push_frame(frame, direction)

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Gradium's streaming API.

        Args:
            text: The text to convert to speech.

        Yields:
            Frame: TTSStartedFrame, then None once the text is sent; audio is
                pushed by the receive task. On failure an ErrorFrame (and,
                for send failures, a TTSStoppedFrame) is yielded instead.
        """
        _state = self._websocket.state if self._websocket is not None else None
        logger.debug(f"{self}: Generating TTS [{text}] {_state}")
        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                self._websocket = None
                await self._connect()

            try:
                # The receive loop stops the TTFB timer on the first audio
                # chunk; start it here so the metric is actually measured.
                await self.start_ttfb_metrics()
                yield TTSStartedFrame()

                msg = self._build_msg(text=text)
                await self._get_websocket().send(json.dumps(msg))
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
                yield TTSStoppedFrame()
                # Reset the connection so the next request starts clean.
                await self._disconnect()
                await self._connect()
                return
            yield None
        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -16,6 +16,8 @@ import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.services.heygen.base_api import BaseAvatarApi, StandardSessionResponse
|
||||
|
||||
|
||||
class AvatarQuality(str, Enum):
|
||||
"""Enum representing different avatar quality levels."""
|
||||
@@ -136,7 +138,7 @@ class HeygenApiError(Exception):
|
||||
self.response_text = response_text
|
||||
|
||||
|
||||
class HeyGenApi:
|
||||
class HeyGenApi(BaseAvatarApi):
|
||||
"""HeyGen Streaming API client."""
|
||||
|
||||
BASE_URL = "https://api.heygen.com/v1"
|
||||
@@ -193,8 +195,8 @@ class HeyGenApi:
|
||||
logger.error(f"Network error while calling HeyGen API: {str(e)}")
|
||||
raise
|
||||
|
||||
async def new_session(self, request_data: NewSessionRequest) -> HeyGenSession:
|
||||
"""Create a new streaming session.
|
||||
async def new_session(self, request_data: NewSessionRequest) -> StandardSessionResponse:
|
||||
"""Create a new streaming session and start it immediately.
|
||||
|
||||
https://docs.heygen.com/reference/new-session
|
||||
|
||||
@@ -202,7 +204,7 @@ class HeyGenApi:
|
||||
request_data: Session configuration parameters.
|
||||
|
||||
Returns:
|
||||
Session information, including ID and access token.
|
||||
StandardSessionResponse: Standardized session information with HeyGen raw response.
|
||||
"""
|
||||
params = {
|
||||
"quality": request_data.quality,
|
||||
@@ -225,9 +227,21 @@ class HeyGenApi:
|
||||
session_info = await self._request("/streaming.new", params)
|
||||
print("heygen session info", session_info)
|
||||
|
||||
return HeyGenSession.model_validate(session_info)
|
||||
heygen_session = HeyGenSession.model_validate(session_info)
|
||||
|
||||
async def start_session(self, session_id: str) -> Any:
|
||||
await self._start_session(heygen_session.session_id)
|
||||
|
||||
# Convert to standardized response
|
||||
return StandardSessionResponse(
|
||||
session_id=heygen_session.session_id,
|
||||
access_token=heygen_session.access_token,
|
||||
livekit_url=heygen_session.url,
|
||||
livekit_agent_token=heygen_session.livekit_agent_token,
|
||||
ws_url=heygen_session.realtime_endpoint,
|
||||
raw_response=heygen_session,
|
||||
)
|
||||
|
||||
async def _start_session(self, session_id: str) -> Any:
|
||||
"""Start the streaming session.
|
||||
|
||||
https://docs.heygen.com/reference/start-session
|
||||
339
src/pipecat/services/heygen/api_liveavatar.py
Normal file
339
src/pipecat/services/heygen/api_liveavatar.py
Normal file
@@ -0,0 +1,339 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LiveAvatar API.
|
||||
|
||||
API to communicate with LiveAvatar Streaming API.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.services.heygen.base_api import BaseAvatarApi, StandardSessionResponse
|
||||
|
||||
|
||||
class AvatarPersona(BaseModel):
    """Persona options attached to a LiveAvatar session request.

    Parameters:
        voice_id (Optional[str]): ID of the voice to use; omitted when None.
        context_id (Optional[str]): Context ID for the avatar; omitted when None.
        language (str): Language code for the avatar (default: "en").
    """

    voice_id: Optional[str] = None
    context_id: Optional[str] = None
    language: str = "en"
|
||||
|
||||
|
||||
class CustomSDKLiveKitConfig(BaseModel):
    """Connection details for a caller-supplied LiveKit room.

    Parameters:
        livekit_url (str): LiveKit server URL.
        livekit_room (str): LiveKit room name.
        livekit_client_token (str): LiveKit client access token.
    """

    livekit_url: str
    livekit_room: str
    livekit_client_token: str
|
||||
|
||||
|
||||
class LiveAvatarNewSessionRequest(BaseModel):
    """Input for requesting a LiveAvatar session token.

    Parameters:
        mode (str): Session mode (default: "CUSTOM").
        avatar_id (str): Unique identifier for the avatar.
        avatar_persona (Optional[AvatarPersona]): Avatar persona configuration,
            or None to send no persona.
        livekit_config (Optional[CustomSDKLiveKitConfig]): Custom LiveKit
            configuration, or None when not provided.
    """

    mode: str = "CUSTOM"
    avatar_id: str
    avatar_persona: Optional[AvatarPersona] = None
    livekit_config: Optional[CustomSDKLiveKitConfig] = None
|
||||
|
||||
|
||||
class SessionTokenData(BaseModel):
    """Payload carried in the ``data`` field of a session token response.

    Parameters:
        session_id (str): Unique identifier for the session.
        session_token (str): Session token for authentication.
    """

    session_id: str
    session_token: str
|
||||
|
||||
|
||||
class SessionTokenResponse(BaseModel):
    """Envelope returned when a LiveAvatar session token is created.

    Parameters:
        code (int): Response status code.
        data (SessionTokenData): Session token data containing session_id and session_token.
        message (str): Response message.
    """

    code: int
    data: SessionTokenData
    message: str
|
||||
|
||||
|
||||
class LiveAvatarSessionData(BaseModel):
    """Connection details carried in the ``data`` field of a session-start response.

    Parameters:
        session_id (str): Unique identifier for the streaming session.
        livekit_url (str): LiveKit server URL for the session.
        livekit_client_token (str): Access token for LiveKit.
        max_session_duration (int): Maximum session duration in seconds.
        ws_url (str): WebSocket URL for the session.
    """

    session_id: str
    livekit_url: str
    livekit_client_token: str
    max_session_duration: int
    ws_url: str
|
||||
|
||||
|
||||
class LiveAvatarSessionResponse(BaseModel):
    """Envelope returned when a LiveAvatar session is started.

    Parameters:
        code (int): Response status code.
        data (LiveAvatarSessionData): Session data containing connection details.
        message (str): Response message.
    """

    code: int
    data: LiveAvatarSessionData
    message: str
|
||||
|
||||
|
||||
class LiveAvatarApiError(Exception):
    """Raised when the LiveAvatar API answers with a non-success status."""

    def __init__(self, message: str, status: int, response_text: str) -> None:
        """Build the error from the failed response.

        Args:
            message: Human-readable error message (becomes ``str(error)``).
            status: HTTP status code of the failed response.
            response_text: Raw response body returned by the API.
        """
        # Keep the response details available to callers that catch the error.
        self.response_text = response_text
        self.status = status
        super().__init__(message)
|
||||
|
||||
|
||||
class LiveAvatarApi(BaseAvatarApi):
|
||||
"""LiveAvatar Streaming API client."""
|
||||
|
||||
BASE_URL = "https://api.liveavatar.com/v1"
|
||||
|
||||
def __init__(self, api_key: str, session: aiohttp.ClientSession) -> None:
    """Set up the client with its credentials and HTTP session.

    Args:
        api_key: LiveAvatar API key
        session: aiohttp client session used for every request
    """
    # No session token is known until one is stored by a later call.
    self._session_token = None
    self._session = session
    self._api_key = api_key
|
||||
|
||||
async def _request(
    self,
    method: str,
    path: str,
    params: Optional[Dict[str, Any]] = None,
    bearer_token: Optional[str] = None,
) -> Any:
    """Send one HTTP request to the LiveAvatar API and return the decoded JSON.

    Args:
        method: HTTP method (GET, POST, etc.).
        path: API endpoint path, appended to BASE_URL.
        params: JSON-serializable parameters sent as the request body.
        bearer_token: Optional bearer token; when given it is used for
            authorization instead of the API key.

    Returns:
        Parsed JSON response data.

    Raises:
        LiveAvatarApiError: If the API response is not successful.
        aiohttp.ClientError: For network-related errors.
    """
    url = f"{self.BASE_URL}{path}"

    # Authenticate with the bearer token when one is supplied, otherwise
    # with the API key.
    if bearer_token:
        auth_header = {"authorization": f"Bearer {bearer_token}"}
    else:
        auth_header = {"X-API-KEY": self._api_key}

    headers = {"accept": "application/json", **auth_header}
    if params is not None:
        headers["content-type"] = "application/json"

    logger.debug(f"LiveAvatar API request: {method} {url}")

    try:
        async with self._session.request(method, url, json=params, headers=headers) as response:
            if response.ok:
                return await response.json()

            response_text = await response.text()
            logger.error(f"LiveAvatar API error: {response_text}")
            raise LiveAvatarApiError(
                f"API request failed with status {response.status}",
                response.status,
                response_text,
            )
    except aiohttp.ClientError as e:
        logger.error(f"Network error while calling LiveAvatar API: {str(e)}")
        raise
|
||||
|
||||
async def create_session_token(
    self, request_data: LiveAvatarNewSessionRequest
) -> SessionTokenResponse:
    """Request a session token from LiveAvatar.

    https://docs.liveavatar.com/reference/create_session_token_v1_sessions_token_post

    Args:
        request_data: Session token configuration parameters.

    Returns:
        Session token information.
    """
    # NOTE(review): request_data.livekit_config is not included in the
    # payload built here -- confirm whether that is intentional.
    params: dict[str, Any] = {
        "mode": request_data.mode,
        "avatar_id": request_data.avatar_id,
    }

    persona = request_data.avatar_persona
    if persona is not None:
        persona_fields = (
            ("voice_id", persona.voice_id),
            ("context_id", persona.context_id),
            ("language", persona.language),
        )
        # Only persona fields that are actually set are sent.
        params["avatar_persona"] = {
            key: value for key, value in persona_fields if value is not None
        }

    response = await self._request("POST", "/sessions/token", params)
    logger.debug("LiveAvatar session token created")

    return SessionTokenResponse.model_validate(response)
|
||||
|
||||
async def start_session(self, session_token: str) -> LiveAvatarSessionResponse:
    """Start a LiveAvatar session for a previously issued session token.

    https://docs.liveavatar.com/reference/start_session_v1_sessions_start_post

    Args:
        session_token: Session token obtained from create_session_token;
            sent as the bearer token of the request.

    Returns:
        The start-session response validated into LiveAvatarSessionResponse.
    """
    payload = await self._request("POST", "/sessions/start", bearer_token=session_token)
    logger.debug(f"LiveAvatar session started")
    return LiveAvatarSessionResponse.model_validate(payload)
|
||||
|
||||
async def stop_session(self, session_id: str, session_token: str) -> Any:
    """Stop a running LiveAvatar session.

    https://docs.liveavatar.com/reference/stop_session_v1_sessions_stop_post

    Args:
        session_id: ID of the session to stop.
        session_token: Session token sent as the bearer token of the request.

    Returns:
        Response data from the stop session API call.

    Raises:
        ValueError: If session ID is not set.
    """
    if session_id:
        # The session ID travels in the JSON body; the token authenticates the call.
        return await self._request(
            "POST",
            "/sessions/stop",
            params={"session_id": session_id},
            bearer_token=session_token,
        )
    raise ValueError("Session ID is not set.")
|
||||
|
||||
async def new_session(
    self, request_data: LiveAvatarNewSessionRequest
) -> StandardSessionResponse:
    """Create a session token and start a session in one call (convenience method).

    Chains create_session_token and start_session, and keeps the issued
    session token on the instance as `_session_token`.

    Args:
        request_data: Session token configuration parameters.

    Returns:
        StandardSessionResponse: Standardized session information, with the
        LiveAvatar start-session response attached as `raw_response`.
    """
    # Step 1: obtain the session token and remember it on the instance.
    token_response = await self.create_session_token(request_data)
    session_token = token_response.data.session_token
    self._session_token = session_token

    # Step 2: start the session with that token.
    session_response = await self.start_session(session_token)
    session = session_response.data

    # Step 3: map the service-specific payload onto the standardized response.
    # TODO: HeyGen will create a new token for Pipecat. Right now a single
    # token is created, which is supposed to be used by the user, so it is
    # reused as the agent token. Because of this, HeyGenTransport is not
    # going to work yet.
    return StandardSessionResponse(
        session_id=session.session_id,
        access_token=session.livekit_client_token,
        livekit_url=session.livekit_url,
        livekit_agent_token=session.livekit_client_token,
        ws_url=session.ws_url,
        raw_response=session_response,
    )
|
||||
|
||||
async def close_session(self, session_id: str) -> Any:
    """Close an active LiveAvatar session (convenience method).

    Delegates to `stop_session()` using the session token stored on the
    instance by the most recent `new_session()` call, so callers do not have
    to track the token themselves.

    Args:
        session_id: ID of the session to close.

    Returns:
        Response data from the stop session API call.

    Raises:
        ValueError: If no session token is available (i.e., `new_session()`
            hasn't been called yet or the stored token is None).

    Note:
        For explicit control over which token is used, call `stop_session()`
        directly with a token argument.
    """
    stored_token = self._session_token
    if stored_token:
        return await self.stop_session(session_id, stored_token)
    raise ValueError("Session token is not set. Call new_session first.")
|
||||
67
src/pipecat/services/heygen/base_api.py
Normal file
67
src/pipecat/services/heygen/base_api.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base API for HeyGen avatar services.
|
||||
|
||||
Base class defining the common interface for HeyGen avatar service APIs.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class StandardSessionResponse(BaseModel):
    """Standardized session response that all HeyGen avatar services will provide.

    This contains the common fields that the client needs to operate,
    while also storing the raw response for service-specific data access.

    Parameters:
        session_id (str): Unique identifier for the streaming session.
        access_token (str): Token for accessing the session securely.
        livekit_agent_token (str): Token for HeyGen's audio agents (Pipecat).
        livekit_url (Optional[str]): LiveKit server URL for the session, if any.
        ws_url (Optional[str]): WebSocket URL for the session, if any.
        raw_response (Any): The service-specific response this was built from.
    """

    session_id: str
    access_token: str
    livekit_agent_token: str

    # Both fields default to None, so they are declared Optional: with a plain
    # `str` annotation an explicitly passed None fails validation.
    livekit_url: Optional[str] = None
    ws_url: Optional[str] = None

    raw_response: Any
|
||||
|
||||
|
||||
class BaseAvatarApi(ABC):
    """Abstract interface implemented by the avatar service API clients."""

    @abstractmethod
    async def new_session(self, request_data: Any) -> StandardSessionResponse:
        """Open a new avatar session.

        Args:
            request_data: Service-specific session request data.

        Returns:
            StandardSessionResponse: Standardized session information.
        """

    @abstractmethod
    async def close_session(self, session_id: str) -> Any:
        """Close a previously opened avatar session.

        Args:
            session_id: ID of the session to close.

        Returns:
            Response data from the close session API call.
        """
|
||||
@@ -16,7 +16,8 @@ import base64
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Awaitable, Callable, Optional
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -28,7 +29,12 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameProcessorSetup
|
||||
from pipecat.services.heygen.api import HeyGenApi, HeyGenSession, NewSessionRequest
|
||||
from pipecat.services.heygen.api_interactive_avatar import HeyGenApi, NewSessionRequest
|
||||
from pipecat.services.heygen.api_liveavatar import (
|
||||
LiveAvatarApi,
|
||||
LiveAvatarNewSessionRequest,
|
||||
)
|
||||
from pipecat.services.heygen.base_api import StandardSessionResponse
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
@@ -45,6 +51,13 @@ except ModuleNotFoundError as e:
|
||||
HEY_GEN_SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class ServiceType(Enum):
    """HeyGen avatar services a client can be configured to use."""

    # Each member's value is its own name as a string.
    INTERACTIVE_AVATAR = "INTERACTIVE_AVATAR"
    LIVE_AVATAR = "LIVE_AVATAR"
|
||||
|
||||
|
||||
class HeyGenCallbacks(BaseModel):
|
||||
"""Callback handlers for HeyGen events.
|
||||
|
||||
@@ -78,10 +91,8 @@ class HeyGenClient:
|
||||
api_key: str,
|
||||
session: aiohttp.ClientSession,
|
||||
params: TransportParams,
|
||||
session_request: NewSessionRequest = NewSessionRequest(
|
||||
avatarName="Shawn_Therapist_public",
|
||||
version="v2",
|
||||
),
|
||||
session_request: Optional[Union[LiveAvatarNewSessionRequest, NewSessionRequest]] = None,
|
||||
service_type: Optional[ServiceType] = None,
|
||||
callbacks: HeyGenCallbacks,
|
||||
connect_as_user: bool = False,
|
||||
) -> None:
|
||||
@@ -91,12 +102,52 @@ class HeyGenClient:
|
||||
api_key: HeyGen API key for authentication
|
||||
session: HTTP client session for API requests
|
||||
params: Transport configuration parameters
|
||||
session_request: Configuration for the HeyGen session (default: uses Shawn_Therapist_public avatar)
|
||||
session_request: Configuration for the HeyGen session (optional)
|
||||
service_type: Type of service to use
|
||||
callbacks: Callback handlers for HeyGen events
|
||||
connect_as_user: Whether to connect using the user token or not (default: False)
|
||||
"""
|
||||
self._api = HeyGenApi(api_key, session=session)
|
||||
self._heyGen_session: Optional[HeyGenSession] = None
|
||||
# Set default service type for backwards compatibility
|
||||
self._service_type = (
|
||||
service_type if service_type is not None else ServiceType.INTERACTIVE_AVATAR
|
||||
)
|
||||
|
||||
# Validate session_request matches service_type if both are provided
|
||||
if session_request is not None and service_type is not None:
|
||||
if service_type == ServiceType.LIVE_AVATAR and not isinstance(
|
||||
session_request, LiveAvatarNewSessionRequest
|
||||
):
|
||||
logger.warning(
|
||||
f"Service type is LIVE_AVATAR but session_request is not SessionTokenRequest. Ignoring session_request."
|
||||
)
|
||||
session_request = None
|
||||
elif service_type == ServiceType.INTERACTIVE_AVATAR and not isinstance(
|
||||
session_request, NewSessionRequest
|
||||
):
|
||||
logger.warning(
|
||||
f"Service type is INTERACTIVE_AVATAR but session_request is not NewSessionRequest. Ignoring session_request."
|
||||
)
|
||||
session_request = None
|
||||
|
||||
# Create default session_request based on service_type if not provided
|
||||
if session_request is None:
|
||||
if self._service_type == ServiceType.INTERACTIVE_AVATAR:
|
||||
session_request = NewSessionRequest(
|
||||
avatar_id="Shawn_Therapist_public",
|
||||
version="v2",
|
||||
)
|
||||
else: # LIVE_AVATAR
|
||||
session_request = LiveAvatarNewSessionRequest(
|
||||
avatar_id="1c690fe7-23e0-49f9-bfba-14344450285b"
|
||||
)
|
||||
|
||||
# Initialize API based on service type
|
||||
if self._service_type == ServiceType.INTERACTIVE_AVATAR:
|
||||
self._api = HeyGenApi(api_key, session=session)
|
||||
else:
|
||||
self._api = LiveAvatarApi(api_key, session=session)
|
||||
|
||||
self._heyGen_session: Optional[StandardSessionResponse] = None
|
||||
self._websocket = None
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._params = params
|
||||
@@ -130,14 +181,12 @@ class HeyGenClient:
|
||||
async def _initialize(self):
|
||||
self._heyGen_session = await self._api.new_session(self._session_request)
|
||||
logger.debug(f"HeyGen sessionId: {self._heyGen_session.session_id}")
|
||||
logger.debug(f"HeyGen realtime_endpoint: {self._heyGen_session.realtime_endpoint}")
|
||||
logger.debug(f"HeyGen livekit URL: {self._heyGen_session.url}")
|
||||
logger.debug(f"HeyGen livekit toke: {self._heyGen_session.access_token}")
|
||||
logger.debug(f"HeyGen realtime_endpoint: {self._heyGen_session.ws_url}")
|
||||
logger.debug(f"HeyGen livekit URL: {self._heyGen_session.livekit_url}")
|
||||
logger.debug(f"HeyGen livekit token: {self._heyGen_session.access_token}")
|
||||
logger.info(
|
||||
f"Full Link: https://meet.livekit.io/custom?liveKitUrl={self._heyGen_session.url}&token={self._heyGen_session.access_token}"
|
||||
f"Full Link: https://meet.livekit.io/custom?liveKitUrl={self._heyGen_session.livekit_url}&token={self._heyGen_session.access_token}"
|
||||
)
|
||||
|
||||
await self._api.start_session(self._heyGen_session.session_id)
|
||||
logger.info("HeyGen session started")
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup) -> None:
|
||||
@@ -222,7 +271,7 @@ class HeyGenClient:
|
||||
return
|
||||
logger.debug(f"HeyGenClient ws connecting")
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self._heyGen_session.realtime_endpoint,
|
||||
uri=self._heyGen_session.ws_url,
|
||||
)
|
||||
self._connected = True
|
||||
self._receive_task = self._task_manager.create_task(
|
||||
@@ -509,7 +558,9 @@ class HeyGenClient:
|
||||
async def _livekit_connect(self):
|
||||
"""Connect to LiveKit room."""
|
||||
try:
|
||||
logger.debug(f"HeyGenClient livekit connecting to room URL: {self._heyGen_session.url}")
|
||||
logger.debug(
|
||||
f"HeyGenClient livekit connecting to room URL: {self._heyGen_session.livekit_url}"
|
||||
)
|
||||
self._livekit_room = rtc.Room()
|
||||
|
||||
@self._livekit_room.on("participant_connected")
|
||||
@@ -574,7 +625,8 @@ class HeyGenClient:
|
||||
if not self._connect_as_user
|
||||
else self._heyGen_session.access_token
|
||||
)
|
||||
await self._livekit_room.connect(self._heyGen_session.url, access_token)
|
||||
|
||||
await self._livekit_room.connect(self._heyGen_session.livekit_url, access_token)
|
||||
logger.debug(f"Successfully connected to LiveKit room: {self._livekit_room.name}")
|
||||
logger.debug(f"Local participant SID: {self._livekit_room.local_participant.sid}")
|
||||
logger.debug(
|
||||
|
||||
@@ -12,7 +12,7 @@ audio/video streaming capabilities through the HeyGen API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -37,8 +37,14 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.heygen.api import NewSessionRequest
|
||||
from pipecat.services.heygen.client import HEY_GEN_SAMPLE_RATE, HeyGenCallbacks, HeyGenClient
|
||||
from pipecat.services.heygen.api_interactive_avatar import NewSessionRequest
|
||||
from pipecat.services.heygen.api_liveavatar import LiveAvatarNewSessionRequest
|
||||
from pipecat.services.heygen.client import (
|
||||
HEY_GEN_SAMPLE_RATE,
|
||||
HeyGenCallbacks,
|
||||
HeyGenClient,
|
||||
ServiceType,
|
||||
)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
# Using the same values that we do in the BaseOutputTransport
|
||||
@@ -72,7 +78,8 @@ class HeyGenVideoService(AIService):
|
||||
*,
|
||||
api_key: str,
|
||||
session: aiohttp.ClientSession,
|
||||
session_request: NewSessionRequest = NewSessionRequest(avatar_id="Shawn_Therapist_public"),
|
||||
session_request: Optional[Union[LiveAvatarNewSessionRequest, NewSessionRequest]] = None,
|
||||
service_type: Optional[ServiceType] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initialize the HeyGen video service.
|
||||
@@ -80,7 +87,8 @@ class HeyGenVideoService(AIService):
|
||||
Args:
|
||||
api_key: HeyGen API key for authentication
|
||||
session: HTTP client session for API requests
|
||||
session_request: Configuration for the HeyGen session (default: uses Shawn_Therapist_public avatar)
|
||||
session_request: Configuration for the HeyGen session
|
||||
service_type: Service type for the avatar session
|
||||
**kwargs: Additional arguments passed to parent AIService
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -91,6 +99,7 @@ class HeyGenVideoService(AIService):
|
||||
self._resampler = create_stream_resampler()
|
||||
self._is_interrupting = False
|
||||
self._session_request = session_request
|
||||
self._service_type = service_type
|
||||
self._other_participant_has_joined = False
|
||||
self._event_id = None
|
||||
self._audio_chunk_size = 0
|
||||
@@ -117,6 +126,7 @@ class HeyGenVideoService(AIService):
|
||||
audio_out_sample_rate=HEY_GEN_SAMPLE_RATE,
|
||||
),
|
||||
session_request=self._session_request,
|
||||
service_type=self._service_type,
|
||||
callbacks=HeyGenCallbacks(
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
|
||||
@@ -8,10 +8,14 @@ import base64
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat import version as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
@@ -26,11 +30,7 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from hume import AsyncHumeClient
|
||||
from hume.tts import (
|
||||
FormatPcm,
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts import FormatPcm, PostedUtterance, PostedUtteranceVoiceWithId
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -40,6 +40,12 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
# Tracking headers for Hume API requests
|
||||
DEFAULT_HEADERS = {
|
||||
"X-Hume-Client-Name": "pipecat",
|
||||
"X-Hume-Client-Version": pipecat_version(),
|
||||
}
|
||||
|
||||
|
||||
class HumeTTSService(WordTTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
@@ -104,7 +110,11 @@ class HumeTTSService(WordTTSService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
# Create a custom httpx.AsyncClient with tracking headers
|
||||
# Headers are included in all requests made by the Hume SDK
|
||||
self._http_client = httpx.AsyncClient(headers=DEFAULT_HEADERS)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key, httpx_client=self._http_client)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
|
||||
# Store voice in the base class (mirrors other services)
|
||||
@@ -138,6 +148,26 @@ class HumeTTSService(WordTTSService):
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
    """Stop the service, then close the HTTP client if one exists.

    Args:
        frame: The end frame.
    """
    await super().stop(frame)
    # The attribute may be missing or falsy; only close a real client.
    http_client = getattr(self, "_http_client", None)
    if http_client:
        await http_client.aclose()
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
    """Cancel the service, then close the HTTP client if one exists.

    Args:
        frame: The cancel frame.
    """
    await super().cancel(frame)
    # The attribute may be missing or falsy; only close a real client.
    http_client = getattr(self, "_http_client", None)
    if http_client:
        await http_client.aclose()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
@@ -215,7 +245,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
self.start_word_timestamps()
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
@@ -44,7 +45,11 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -166,20 +171,17 @@ class LLMService(AIService):
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
def __init__(self, run_in_parallel: bool = True, wait_for_all: bool = False, **kwargs):
|
||||
def __init__(self, run_in_parallel: bool = True, **kwargs):
|
||||
"""Initialize the LLM service.
|
||||
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
wait_for_all: Whether to wait for all function calls (parallel or
|
||||
sequential) to complete. Defaults to False.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._wait_for_all = wait_for_all
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
@@ -546,29 +548,10 @@ class LLMService(AIService):
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
|
||||
if self._wait_for_all:
|
||||
# Protect gather from being cancelled. This will protect all tasks
|
||||
# form being cancelled. That is fine, because we cancel them
|
||||
# explicitly when handling the interruption (InterruptionFrame). We
|
||||
# need to set `return_exceptions=True` because `asyncio.shield()`
|
||||
# will get cancelled (from FrameProcessor process task), then
|
||||
# `asyncio.gather()` will keep running (because it was protected by
|
||||
# the shield). Then, individiaul function call tasks will be
|
||||
# cancelled by us and we don't need to propagate those
|
||||
# CancelledErrors at that point.
|
||||
await asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
|
||||
|
||||
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
if self._wait_for_all:
|
||||
# Run each function call sequentially, waiting for each to complete.
|
||||
for runner_item in runner_items:
|
||||
self._function_call_tasks[None] = runner_item
|
||||
await self._run_function_call(runner_item)
|
||||
del self._function_call_tasks[None]
|
||||
else:
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -46,17 +46,24 @@ class MCPClient(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
tools_filter: Optional[List[str]] = None,
|
||||
tools_output_filters: Optional[Dict[str, Callable[[Any], Any]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
Args:
|
||||
server_params: Server connection parameters (stdio or SSE).
|
||||
tools_filter: Optional list of tool names to register. If None, all tools are registered.
|
||||
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
|
||||
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
|
||||
**kwargs: Additional arguments passed to the parent BaseObject.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._server_params = server_params
|
||||
self._session = ClientSession
|
||||
self._tools_filter = tools_filter
|
||||
self._tools_output_filters = tools_output_filters or {}
|
||||
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
self._client = stdio_client
|
||||
@@ -264,13 +271,26 @@ class MCPClient(BaseObject):
|
||||
else:
|
||||
# logger.debug(f"Non-text result content: '{content}'")
|
||||
pass
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
logger.error(f"Error getting content from {function_name} results.")
|
||||
|
||||
final_response = response if len(response) else "Sorry, could not call the mcp tool"
|
||||
await result_callback(final_response)
|
||||
# Apply output filter if configured for this tool
|
||||
if function_name in self._tools_output_filters:
|
||||
try:
|
||||
response = self._tools_output_filters[function_name](response)
|
||||
logger.debug(f"Final response (after filter): {response}")
|
||||
|
||||
except Exception:
|
||||
logger.error(f"Error applying output filter for {function_name}")
|
||||
response = ""
|
||||
|
||||
if response and len(response) and isinstance(response, str):
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
response = "Sorry, could not call the mcp tool"
|
||||
|
||||
await result_callback(response)
|
||||
|
||||
async def _list_tools_helper(self, session):
|
||||
available_tools = await session.list_tools()
|
||||
@@ -283,6 +303,12 @@ class MCPClient(BaseObject):
|
||||
|
||||
for tool in available_tools.tools:
|
||||
tool_name = tool.name
|
||||
|
||||
# Apply tools filter if configured
|
||||
if self._tools_filter and tool_name not in self._tools_filter:
|
||||
logger.debug(f"Skipping tool '{tool_name}' - not in allowed tools list")
|
||||
continue
|
||||
|
||||
logger.debug(f"Processing tool: {tool_name}")
|
||||
logger.debug(f"Tool description: {tool.description}")
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ from PIL import Image
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
VisionFullResponseEndFrame,
|
||||
VisionFullResponseStartFrame,
|
||||
VisionTextFrame,
|
||||
)
|
||||
from pipecat.services.vision_service import VisionService
|
||||
|
||||
@@ -104,10 +106,6 @@ class MoondreamService(VisionService):
|
||||
|
||||
Args:
|
||||
frame: The image frame to process.
|
||||
|
||||
Yields:
|
||||
Frame: TextFrame containing the generated image description, or ErrorFrame
|
||||
if analysis fails.
|
||||
"""
|
||||
if not self._model:
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
@@ -123,4 +121,6 @@ class MoondreamService(VisionService):
|
||||
|
||||
description = await asyncio.to_thread(get_image_description, frame.image, frame.text)
|
||||
|
||||
yield TextFrame(text=description)
|
||||
yield VisionFullResponseStartFrame()
|
||||
yield VisionTextFrame(text=description)
|
||||
yield VisionFullResponseEndFrame()
|
||||
|
||||
@@ -133,6 +133,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
self._retry_timeout_secs = retry_timeout_secs
|
||||
self._retry_on_timeout = retry_on_timeout
|
||||
self.set_model_name(model)
|
||||
self._full_model_name: str = ""
|
||||
self._client = self.create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
@@ -185,6 +186,22 @@ class BaseOpenAILLMService(LLMService):
|
||||
"""
|
||||
return True
|
||||
|
||||
def set_full_model_name(self, full_model_name: str):
    """Store the full name of the AI model on this service.

    Args:
        full_model_name: The full name of the AI model to use.
    """
    self._full_model_name = full_model_name
|
||||
|
||||
def get_full_model_name(self):
    """Return the full model name currently stored on this service.

    Returns:
        The full name of the AI model being used.
    """
    return self._full_model_name
|
||||
|
||||
async def get_chat_completions(
|
||||
self, params_from_context: OpenAILLMInvocationParams
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -259,17 +276,23 @@ class BaseOpenAILLMService(LLMService):
|
||||
"""
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
invocation_params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context
|
||||
)
|
||||
else:
|
||||
messages = context.messages
|
||||
invocation_params = OpenAILLMInvocationParams(
|
||||
messages=context.messages, tools=context.tools, tool_choice=context.tool_choice
|
||||
)
|
||||
|
||||
# Build params using the same method as streaming completions
|
||||
params = self.build_chat_completion_params(invocation_params)
|
||||
|
||||
# Override for non-streaming
|
||||
params["stream"] = False
|
||||
params.pop("stream_options", None)
|
||||
|
||||
# LLM completion
|
||||
response = await self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
)
|
||||
response = await self._client.chat.completions.create(**params)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -360,6 +383,9 @@ class BaseOpenAILLMService(LLMService):
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.model and self.get_full_model_name() != chunk.model:
|
||||
self.set_full_model_name(chunk.model)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
|
||||
@@ -31,7 +31,11 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextWordTTSService,
|
||||
InterruptibleTTSService,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
@@ -381,7 +385,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
if msg["type"] == "chunk":
|
||||
# Process audio chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -608,3 +612,332 @@ class RimeHttpTTSService(TTSService):
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
"""Pipecat TTS service for Rime's non-JSON WebSocket API.
|
||||
|
||||
This service enables Text-to-Speech synthesis over WebSocket endpoints
|
||||
that require plain text (not JSON) messages and return raw audio bytes.
|
||||
It is designed for use with TTS models like Arcana, which currently do
|
||||
not support JSON-based WebSocket protocols (though this may change in
|
||||
the future).
|
||||
|
||||
Limitations:
|
||||
- Does not support word-level timestamps or context IDs.
|
||||
- Intended specifically for integrations where the TTS provider only
|
||||
accepts and returns non-JSON messages.
|
||||
|
||||
Note:
|
||||
- Arcana and similar models may add JSON WebSocket support in the
|
||||
future. This service focuses on the current plain text protocol.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
    """Optional synthesis settings for the Rime non-JSON WebSocket TTS service.

    Every field defaults to ``None``. The ranges listed below are descriptive
    only; this model declares no validators that enforce them.

    Args:
        language: Language for synthesis.
        segment: Text segmentation mode ("immediate", "bySentence", "never").
        repetition_penalty: Token repetition penalty (1.0-2.0).
        temperature: Sampling temperature (0.0-1.0).
        top_p: Cumulative probability threshold (0.0-1.0).
        extra: Additional parameters to pass to the API (for future compatibility).
    """

    # Language and segmentation behaviour.
    language: Optional[Language] = None
    segment: Optional[str] = None

    # Sampling controls.
    repetition_penalty: Optional[float] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    # Free-form pass-through parameters.
    extra: Optional[dict[str, Any]] = None
|
||||
|
||||
def __init__(
    self,
    *,
    api_key: str,
    voice_id: str,
    url: str = "wss://users.rime.ai/ws",
    model: str = "arcana",
    audio_format: str = "pcm",
    sample_rate: Optional[int] = None,
    params: Optional[InputParams] = None,
    aggregate_sentences: Optional[bool] = True,
    **kwargs,
):
    """Set up the Rime non-JSON WebSocket TTS service.

    Args:
        api_key: Rime API key for authentication.
        voice_id: ID of the voice to use.
        url: Rime websocket API endpoint.
        model: Model ID to use for synthesis.
        audio_format: Audio format to use.
        sample_rate: Audio sample rate in Hz.
        params: Additional configuration parameters.
        aggregate_sentences: Whether to aggregate sentences within the TTSService.
        **kwargs: Additional arguments passed to parent class.
    """
    super().__init__(
        sample_rate=sample_rate,
        aggregate_sentences=aggregate_sentences,
        push_stop_frames=True,
        pause_frame_processing=True,
        **kwargs,
    )
    if not params:
        params = RimeNonJsonTTSService.InputParams()

    self._api_key = api_key
    self._url = url
    self._voice_id = voice_id
    self._model = model

    # These entries are later serialized as URL query parameters, in
    # insertion order, when the websocket is opened.
    self._settings = {
        "speaker": voice_id,
        "modelId": model,
        "audioFormat": audio_format,
        "samplingRate": sample_rate,
    }
    query = self._settings

    if params.language:
        query["lang"] = self.language_to_service_language(params.language)

    # Optional tunables are only included when explicitly provided.
    for option in ("segment", "repetition_penalty", "temperature", "top_p"):
        chosen = getattr(params, option)
        if chosen is not None:
            query[option] = chosen

    # Add any extra parameters for future compatibility
    if params.extra:
        query.update(params.extra)

    self._started = False
    self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
    """Report whether this service can generate processing metrics.

    Returns:
        Always True for the Rime Non-JSON WebSocket service.
    """
    supports_metrics = True
    return supports_metrics
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str:
    """Translate a pipecat Language value into Rime's language code.

    The conversion is delegated entirely to ``language_to_rime_language``.

    Args:
        language: The Language enum value to convert.

    Returns:
        The Rime language code produced by the helper (per the existing
        documentation, a three-letter code such as 'eng', with a fallback to
        the base language code for unverified languages).
    """
    rime_code = language_to_rime_language(language)
    return rime_code
|
||||
|
||||
async def start(self, frame: StartFrame):
    """Start the service and open the WebSocket connection.

    Args:
        frame: The start frame containing initialization parameters.
    """
    await super().start(frame)
    # Refresh the sampling rate from the service's sample_rate after the
    # parent has started, before the connection URL is built.
    self._settings.update({"samplingRate": self.sample_rate})
    await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
    """Stop the service, then tear down the WebSocket connection.

    Args:
        frame: The end frame that triggered the stop.
    """
    await super().stop(frame)
    await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Cancel the service, then tear down the WebSocket connection.

    Args:
        frame: The cancel frame that triggered the cancellation.
    """
    await super().cancel(frame)
    await self._disconnect()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Forward a frame and clear the started flag on stop conditions.

    Args:
        frame: The frame to push.
        direction: The direction to push the frame.
    """
    await super().push_frame(frame, direction)

    ends_utterance = isinstance(frame, (TTSStoppedFrame, InterruptionFrame))
    if not ends_utterance:
        return
    # The next run_tts call must emit a fresh TTSStartedFrame.
    self._started = False
|
||||
|
||||
async def _connect(self):
    """Open the WebSocket and, if needed, spawn the receive task."""
    await self._connect_websocket()

    # Nothing to start when the connection failed or a reader already exists.
    if not self._websocket or self._receive_task:
        return

    reader = self._receive_task_handler(self._report_error)
    self._receive_task = self.create_task(reader)
|
||||
|
||||
async def _disconnect(self):
    """Cancel the receive task (if any) and close the WebSocket."""
    pending_reader = self._receive_task
    if pending_reader:
        await self.cancel_task(pending_reader)
        self._receive_task = None

    await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
    """Establish WebSocket connection to Rime non-JSON websocket.

    Does nothing when a connection is already open. Otherwise every
    non-None entry of ``self._settings`` is sent as a URL query parameter
    and the API key as a Bearer ``Authorization`` header. On success the
    ``on_connected`` event handler is called. On any failure an error is
    pushed, ``self._websocket`` is reset to None and ``on_connection_error``
    is called; the exception is not re-raised.
    """
    # Function-scope import so the block does not depend on a new
    # file-level import.
    from urllib.parse import quote, urlencode

    try:
        if self._websocket and self._websocket.state is State.OPEN:
            return

        # Build URL with query parameters (only non-None values). Values are
        # percent-encoded so that spaces, "&", "=" or "#" in a setting
        # (e.g. one supplied through `extra`) cannot corrupt the query string.
        query = urlencode(
            {key: value for key, value in self._settings.items() if value is not None},
            quote_via=quote,
        )
        url = f"{self._url}?{query}"
        headers = {"Authorization": f"Bearer {self._api_key}"}
        self._websocket = await websocket_connect(
            url, additional_headers=headers, max_size=1024 * 1024 * 16
        )
        await self._call_event_handler("on_connected")
    except Exception as e:
        await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        self._websocket = None
        await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
    """Close WebSocket connection and clean up state.

    Stops all metrics, then closes the socket if one exists. The ``<EOS>``
    command is only sent while the connection is still open. Any exception
    is reported through ``push_error`` rather than raised. In all cases the
    started flag and the socket reference are cleared and the
    ``on_disconnected`` event handler is called.
    """
    try:
        await self.stop_all_metrics()
        if self._websocket:
            # Send EOS command to gracefully close. Skip it when the socket
            # is no longer open: sending would raise and surface a spurious
            # error during routine cleanup after a dropped connection.
            if self._websocket.state is State.OPEN:
                await self._websocket.send("<EOS>")
            await self._websocket.close()
            logger.debug("Disconnected from Rime non-JSON websocket")
    except Exception as e:
        await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
    finally:
        self._started = False
        self._websocket = None
        await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
    """Return the current WebSocket, raising if there is none.

    Raises:
        Exception: If no WebSocket connection is set.
    """
    if not self._websocket:
        raise Exception("Websocket not connected")
    return self._websocket
|
||||
|
||||
async def flush_audio(self):
    """Ask the server to flush pending synthesis by sending ``<FLUSH>``.

    Does nothing when there is no WebSocket connection.
    """
    if self._websocket:
        logger.trace(f"{self}: flushing audio")
        await self._websocket.send("<FLUSH>")
|
||||
|
||||
async def _receive_messages(self):
    """Turn incoming binary WebSocket messages into audio frames.

    Non-bytes messages are ignored. An exception raised while handling a
    single message is reported through ``push_error`` and the loop continues.
    """
    async for payload in self._get_websocket():
        try:
            # Rime Arcana sends raw audio bytes directly (not JSON)
            if not isinstance(payload, bytes):
                continue

            await self.stop_ttfb_metrics()
            audio_frame = TTSAudioRawFrame(
                audio=payload,
                sample_rate=self.sample_rate,
                num_channels=1,
            )
            await self.push_frame(audio_frame)
        except Exception as exc:
            await self.push_error(error_msg=f"Error: {exc}", exception=exc)
|
||||
|
||||
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
    """Send text to Rime for synthesis over the WebSocket.

    Audio is not yielded here; it arrives on the WebSocket and is pushed by
    the receive loop. This generator yields the started/stopped/error control
    frames, and ``None`` after a successful send.

    Args:
        text: The text to synthesize into speech.

    Yields:
        Frame: TTSStartedFrame on the first send of an utterance, ErrorFrame
        and TTSStoppedFrame on failure, otherwise None.
    """
    logger.debug(f"{self}: Generating TTS [{text}]")
    try:
        # Connect when the socket is missing or reported as closed.
        if not self._websocket or self._websocket.state is State.CLOSED:
            await self._connect()

        try:
            if not self._started:
                await self.start_ttfb_metrics()
                yield TTSStartedFrame()
                self._started = True
            # Send bare text (not JSON)
            socket = self._get_websocket()
            await socket.send(text)
            await self.start_tts_usage_metrics(text)
        except Exception as send_error:
            yield ErrorFrame(error=f"Unknown error occurred: {send_error}")
            yield TTSStoppedFrame()
            # Rebuild the connection so the next request starts clean.
            await self._disconnect()
            await self._connect()
            return
        else:
            yield None
    except Exception as outer_error:
        yield ErrorFrame(error=f"Unknown error occurred: {outer_error}")
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
    """Apply new settings and reconnect when any of them changed.

    Since all settings are WebSocket URL query parameters,
    any setting change requires reconnecting to apply the new values.

    Args:
        settings: Mapping of setting names to their new values.
    """
    # Snapshot of self._settings taken before anything is modified; every
    # comparison below is made against this snapshot.
    before = self._settings.copy()
    reconnect = False

    # Let parent class handle standard settings (voice, model, language)
    await super()._update_settings(settings)

    # Voice: mirror the current voice id into the query settings.
    if "voice" in settings or "voice_id" in settings:
        self._settings["speaker"] = self._voice_id
        if before.get("speaker") != self._voice_id:
            logger.info(f"Switching TTS voice to: [{self._voice_id}]")
            reconnect = True

    # Model: mirror the current model into the query settings.
    if "model" in settings:
        self._settings["modelId"] = self._model
        if before.get("modelId") != self._model:
            logger.info(f"Switching TTS model to: [{self._model}]")
            reconnect = True

    # Language is converted to the service code before comparing.
    if "language" in settings:
        lang_code = self.language_to_service_language(settings["language"])
        if lang_code and lang_code != before.get("lang"):
            logger.info(f"Updating language to: [{lang_code}]")
            self._settings["lang"] = lang_code
            reconnect = True

    # Named tunables are stored under their own key.
    tunables = ("segment", "repetition_penalty", "temperature", "top_p")
    for name in tunables:
        if name not in settings:
            continue
        incoming = settings[name]
        if incoming != before.get(name):
            logger.info(f"Updating {name} to: [{incoming}]")
            self._settings[name] = incoming
            reconnect = True

    # Anything else is treated as an extra pass-through parameter.
    handled = ("voice", "voice_id", "model", "language") + tunables
    for name, incoming in settings.items():
        if name in handled:
            continue
        if incoming != before.get(name):
            logger.info(f"Updating extra parameter {name} to: [{incoming}]")
            self._settings[name] = incoming
            reconnect = True

    # Reconnect if any setting changed
    if reconnect:
        logger.debug("Settings changed, reconnecting WebSocket with new parameters")
        await self._disconnect()
        await self._connect()
|
||||
|
||||
@@ -17,7 +17,6 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -25,13 +24,12 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
@@ -87,6 +85,7 @@ class SonioxInputParams(BaseModel):
|
||||
audio_format: Audio format to use for transcription.
|
||||
num_channels: Number of channels to use for transcription.
|
||||
language_hints: List of language hints to use for transcription.
|
||||
language_hints_strict: If true, strictly enforce language hints (only transcribe in provided languages).
|
||||
context: Customization for transcription. String for models with context_version 1 and ContextObject for models with context_version 2.
|
||||
enable_speaker_diarization: Whether to enable speaker diarization. Tokens are annotated with speaker IDs.
|
||||
enable_language_identification: Whether to enable language identification. Tokens are annotated with language IDs.
|
||||
@@ -99,6 +98,7 @@ class SonioxInputParams(BaseModel):
|
||||
num_channels: Optional[int] = 1
|
||||
|
||||
language_hints: Optional[List[Language]] = None
|
||||
language_hints_strict: Optional[bool] = None
|
||||
context: Optional[SonioxContextObject | str] = None
|
||||
|
||||
enable_speaker_diarization: Optional[bool] = False
|
||||
@@ -134,7 +134,7 @@ def _prepare_language_hints(
|
||||
return list(set(prepared_languages))
|
||||
|
||||
|
||||
class SonioxSTTService(STTService):
|
||||
class SonioxSTTService(WebsocketSTTService):
|
||||
"""Speech-to-Text service using Soniox's WebSocket API.
|
||||
|
||||
This service connects to Soniox's WebSocket API for real-time transcription
|
||||
@@ -173,7 +173,6 @@ class SonioxSTTService(STTService):
|
||||
self.set_model_name(params.model)
|
||||
self._params = params
|
||||
self._vad_force_turn_endpoint = vad_force_turn_endpoint
|
||||
self._websocket = None
|
||||
|
||||
self._final_transcription_buffer = []
|
||||
self._last_tokens_received: Optional[float] = None
|
||||
@@ -188,59 +187,7 @@ class SonioxSTTService(STTService):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
if self._websocket:
|
||||
return
|
||||
|
||||
self._websocket = await websocket_connect(self._url)
|
||||
|
||||
if not self._websocket:
|
||||
await self.push_error(error_msg=f"Unable to connect to Soniox API at {self._url}")
|
||||
|
||||
# If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
|
||||
# Either one or the other is required.
|
||||
enable_endpoint_detection = not self._vad_force_turn_endpoint
|
||||
|
||||
context = self._params.context
|
||||
if isinstance(context, SonioxContextObject):
|
||||
context = context.model_dump()
|
||||
|
||||
# Send the initial configuration message.
|
||||
config = {
|
||||
"api_key": self._api_key,
|
||||
"model": self._model_name,
|
||||
"audio_format": self._params.audio_format,
|
||||
"num_channels": self._params.num_channels or 1,
|
||||
"enable_endpoint_detection": enable_endpoint_detection,
|
||||
"sample_rate": self.sample_rate,
|
||||
"language_hints": _prepare_language_hints(self._params.language_hints),
|
||||
"context": context,
|
||||
"enable_speaker_diarization": self._params.enable_speaker_diarization,
|
||||
"enable_language_identification": self._params.enable_language_identification,
|
||||
"client_reference_id": self._params.client_reference_id,
|
||||
}
|
||||
|
||||
# Send the configuration message.
|
||||
await self._websocket.send(json.dumps(config))
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
if self._websocket and not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _cleanup(self):
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
# Task cannot cancel itself. If task called _cleanup() we expect it to cancel itself.
|
||||
if self._receive_task != asyncio.current_task():
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Soniox STT websocket connection.
|
||||
@@ -253,6 +200,7 @@ class SonioxSTTService(STTService):
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._send_stop_recording()
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Soniox STT websocket connection.
|
||||
@@ -265,7 +213,7 @@ class SonioxSTTService(STTService):
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._cleanup()
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Soniox STT Service.
|
||||
@@ -311,28 +259,111 @@ class SonioxSTTService(STTService):
|
||||
# Send stop recording message
|
||||
await self._websocket.send("")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Connection has to be open all the time."""
|
||||
async def _connect(self):
    """Connect to the Soniox service.

    Opens the websocket, then starts the receive task and the keepalive
    task, each only if a websocket exists and that task is not already set.
    """
    await self._connect_websocket()

    if not self._receive_task and self._websocket:
        receiver = self._receive_task_handler(self._report_error)
        self._receive_task = self.create_task(receiver)

    if not self._keepalive_task and self._websocket:
        keepalive = self._keepalive_task_handler()
        self._keepalive_task = self.create_task(keepalive)
|
||||
|
||||
async def _disconnect(self):
    """Disconnect from the Soniox service.

    Cancels the keepalive task, then the receive task, and finally closes
    the websocket connection.
    """
    keepalive = self._keepalive_task
    if keepalive:
        await self.cancel_task(keepalive)
        self._keepalive_task = None

    receiver = self._receive_task
    if receiver:
        await self.cancel_task(receiver)
        self._receive_task = None

    await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish the websocket connection to Soniox."""
|
||||
try:
|
||||
while True:
|
||||
logger.trace("Sending keepalive message")
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(KEEPALIVE_MESSAGE)
|
||||
else:
|
||||
logger.debug("WebSocket connection closed.")
|
||||
break
|
||||
await asyncio.sleep(5)
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection
|
||||
logger.debug("WebSocket connection closed, keepalive task stopped.")
|
||||
logger.debug("Connecting to Soniox STT")
|
||||
|
||||
self._websocket = await websocket_connect(self._url)
|
||||
|
||||
if not self._websocket:
|
||||
await self.push_error(error_msg=f"Unable to connect to Soniox API at {self._url}")
|
||||
raise Exception(f"Unable to connect to Soniox API at {self._url}")
|
||||
|
||||
# If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
|
||||
# Either one or the other is required.
|
||||
enable_endpoint_detection = not self._vad_force_turn_endpoint
|
||||
|
||||
context = self._params.context
|
||||
if isinstance(context, SonioxContextObject):
|
||||
context = context.model_dump()
|
||||
|
||||
# Send the initial configuration message.
|
||||
config = {
|
||||
"api_key": self._api_key,
|
||||
"model": self._model_name,
|
||||
"audio_format": self._params.audio_format,
|
||||
"num_channels": self._params.num_channels or 1,
|
||||
"enable_endpoint_detection": enable_endpoint_detection,
|
||||
"sample_rate": self.sample_rate,
|
||||
"language_hints": _prepare_language_hints(self._params.language_hints),
|
||||
"language_hints_strict": self._params.language_hints_strict,
|
||||
"context": context,
|
||||
"enable_speaker_diarization": self._params.enable_speaker_diarization,
|
||||
"enable_language_identification": self._params.enable_language_identification,
|
||||
"client_reference_id": self._params.client_reference_id,
|
||||
}
|
||||
|
||||
# Send the configuration message.
|
||||
await self._websocket.send(json.dumps(config))
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to Soniox STT")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(error_msg=f"Unable to connect to Soniox: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
if not self._websocket:
|
||||
return
|
||||
async def _disconnect_websocket(self):
    """Close the Soniox websocket, if any, and signal disconnection.

    Errors raised while closing are reported through ``push_error``. The
    websocket reference is always cleared and ``on_disconnected`` is always
    called.
    """
    try:
        socket = self._websocket
        if socket:
            logger.debug("Disconnecting from Soniox STT")
            await socket.close()
    except Exception as close_error:
        await self.push_error(
            error_msg=f"Error closing websocket: {close_error}", exception=close_error
        )
    finally:
        self._websocket = None
        await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
    """Return the current WebSocket connection.

    Returns:
        The WebSocket connection.

    Raises:
        Exception: If WebSocket is not connected.
    """
    if not self._websocket:
        raise Exception("Websocket not connected")
    return self._websocket
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process websocket messages.
|
||||
|
||||
Continuously processes messages from the websocket connection.
|
||||
"""
|
||||
# Transcription frame will be only sent after we get the "endpoint" event.
|
||||
self._final_transcription_buffer = []
|
||||
|
||||
@@ -351,8 +382,8 @@ class SonioxSTTService(STTService):
|
||||
await self.stop_processing_metrics()
|
||||
self._final_transcription_buffer = []
|
||||
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
content = json.loads(message)
|
||||
|
||||
tokens = content["tokens"]
|
||||
@@ -404,7 +435,7 @@ class SonioxSTTService(STTService):
|
||||
# In case of error, still send the final transcript (if any remaining in the buffer).
|
||||
await send_endpoint_transcript()
|
||||
await self.push_error(
|
||||
error_msg=f"Error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
error_msg=f"Error: {error_code} (_receive_messages) - {error_message}"
|
||||
)
|
||||
|
||||
finished = content.get("finished")
|
||||
@@ -412,11 +443,24 @@ class SonioxSTTService(STTService):
|
||||
# When finished, still send the final transcript (if any remaining in the buffer).
|
||||
await send_endpoint_transcript()
|
||||
logger.debug("Transcription finished.")
|
||||
await self._cleanup()
|
||||
return
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection.
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing message: {e}")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Connection has to be open all the time."""
|
||||
try:
|
||||
while True:
|
||||
logger.trace("Sending keepalive message")
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(KEEPALIVE_MESSAGE)
|
||||
else:
|
||||
logger.debug("WebSocket connection closed.")
|
||||
break
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error receiving message: {e}", exception=e)
|
||||
logger.debug(f"Keepalive task stopped: {e}")
|
||||
|
||||
@@ -651,15 +651,21 @@ class WordTTSService(TTSService):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._initial_word_timestamp = -1
|
||||
self._initial_word_times = []
|
||||
self._words_task = None
|
||||
self._llm_response_started: bool = False
|
||||
|
||||
def start_word_timestamps(self):
|
||||
async def start_word_timestamps(self):
|
||||
"""Start tracking word timestamps from the current time."""
|
||||
if self._initial_word_timestamp == -1:
|
||||
self._initial_word_timestamp = self.get_clock().get_time()
|
||||
# If we cached some initial word times (because we didn't receive
|
||||
# audio), let's add them now.
|
||||
if self._initial_word_times:
|
||||
await self._add_word_timestamps(self._initial_word_times)
|
||||
self._initial_word_times = []
|
||||
|
||||
def reset_word_timestamps(self):
|
||||
async def reset_word_timestamps(self):
|
||||
"""Reset word timestamp tracking."""
|
||||
self._initial_word_timestamp = -1
|
||||
|
||||
@@ -669,8 +675,12 @@ class WordTTSService(TTSService):
|
||||
Args:
|
||||
word_times: List of (word, timestamp) tuples where timestamp is in seconds.
|
||||
"""
|
||||
for word, timestamp in word_times:
|
||||
await self._words_queue.put((word, seconds_to_nanoseconds(timestamp)))
|
||||
if self._initial_word_timestamp == -1:
|
||||
# Cache word timestamps and don't add them until we have started
|
||||
# (i.e. we have some audio).
|
||||
self._initial_word_times.extend(word_times)
|
||||
else:
|
||||
await self._add_word_timestamps(word_times)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the word TTS service.
|
||||
@@ -716,7 +726,7 @@ class WordTTSService(TTSService):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._llm_response_started = False
|
||||
self.reset_word_timestamps()
|
||||
await self.reset_word_timestamps()
|
||||
|
||||
def _create_words_task(self):
|
||||
if not self._words_task:
|
||||
@@ -728,13 +738,17 @@ class WordTTSService(TTSService):
|
||||
await self.cancel_task(self._words_task)
|
||||
self._words_task = None
|
||||
|
||||
async def _add_word_timestamps(self, word_times: List[Tuple[str, float]]):
|
||||
for word, timestamp in word_times:
|
||||
await self._words_queue.put((word, seconds_to_nanoseconds(timestamp)))
|
||||
|
||||
async def _words_task_handler(self):
|
||||
last_pts = 0
|
||||
while True:
|
||||
frame = None
|
||||
(word, timestamp) = await self._words_queue.get()
|
||||
if word == "Reset" and timestamp == 0:
|
||||
self.reset_word_timestamps()
|
||||
await self.reset_word_timestamps()
|
||||
if self._llm_response_started:
|
||||
self._llm_response_started = False
|
||||
frame = LLMFullResponseEndFrame()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .stt import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ultravox", "ultravox.stt")
|
||||
|
||||
549
src/pipecat/services/ultravox/llm.py
Normal file
549
src/pipecat/services/ultravox/llm.py
Normal file
@@ -0,0 +1,549 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Ultravox Realtime API service implementation.
|
||||
|
||||
This module provides real-time conversational AI capabilities using Ultravox's
|
||||
Realtime API, supporting both text and audio modalities with
|
||||
voice transcription, streaming responses, and tool usage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from openai.types import chat as openai_chat_types
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputTextRawFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from websockets.asyncio import client as websocket_client
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Ultravox, you need to `pip install pipecat-ai[ultravox]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AgentInputParams(BaseModel):
    """Settings for starting an Ultravox Realtime call from a pre-defined Agent.

    Parameters:
        api_key: Ultravox API key for authentication.
        agent_id: The ID of the Ultravox Realtime agent you'd like to use. Agents
            are pre-configured to handle calls consistently. You can create and edit
            agents in the Ultravox console (https://app.ultravox.ai/agents) or using
            the Ultravox API (https://docs.ultravox.ai/api-reference/agents/agents-post).
        template_context: Context variables to use when instantiating a call with the
            agent. Defaults to an empty dict.
        metadata: Metadata to attach to the call. Defaults to an empty dict.
        max_duration: The maximum duration of the call, between 10 seconds and
            1 hour. Defaults to None, which will use the agent's default maximum
            duration.
        extra: Extra parameters to include in the agent call creation request. Defaults
            to an empty dict. See the Ultravox API documentation for valid arguments:
            https://docs.ultravox.ai/api-reference/agents/agents-calls-post
    """

    # Required identification.
    api_key: str
    agent_id: uuid.UUID

    # Optional call configuration; mutable defaults use factories.
    template_context: Dict[str, Any] = Field(default_factory=dict)
    metadata: Dict[str, str] = Field(default_factory=dict)
    max_duration: Optional[datetime.timedelta] = Field(
        default=None,
        ge=datetime.timedelta(seconds=10),
        le=datetime.timedelta(hours=1),
    )
    extra: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class OneShotInputParams(BaseModel):
    """Input parameters for Ultravox Realtime generation using a one-off call.

    Parameters:
        api_key: Ultravox API key for authentication.
        system_prompt: System prompt to guide the model's behavior. Defaults to None.
        temperature: Sampling temperature for response generation, in the range
            [0.0, 1.0]. Defaults to 0.
        model: Model identifier to use. Defaults to None, in which case no
            explicit model is requested (presumably the Ultravox service then
            applies its own default, e.g. "fixie-ai/ultravox" -- confirm against
            the Ultravox API documentation).
        voice: Voice identifier for speech generation. Defaults to None.
        metadata: Metadata to attach to the call. Defaults to an empty dict.
        max_duration: The maximum duration of the call, between 10 seconds and
            1 hour. Defaults to one hour.
        extra: Extra parameters to include in the call creation request. Defaults
            to an empty dict. See the Ultravox API documentation for valid arguments:
            https://docs.ultravox.ai/api-reference/calls/calls-post
    """

    api_key: str
    system_prompt: Optional[str] = None
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)
    model: Optional[str] = None
    voice: Optional[uuid.UUID] = None
    metadata: Dict[str, str] = Field(default_factory=dict)
    max_duration: datetime.timedelta = Field(
        default=datetime.timedelta(hours=1),
        ge=datetime.timedelta(seconds=10),
        le=datetime.timedelta(hours=1),
    )
    extra: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class JoinUrlInputParams(BaseModel):
    """Settings for attaching to an Ultravox Realtime call that already exists.

    Parameters:
        join_url: The join URL of the existing Ultravox Realtime call.
    """

    join_url: str
|
||||
|
||||
|
||||
class UltravoxRealtimeLLMService(LLMService):
|
||||
"""Provides access to the Ultravox Realtime API.
|
||||
|
||||
This service enables real-time conversations with Ultravox, supporting both
|
||||
text and audio output. It handles voice transcription, streaming audio
|
||||
responses, and tool usage.
|
||||
|
||||
Note: Ultravox is an audio-native model, so voice transcriptions are not used
|
||||
by the model and may not always align with its understanding of user input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
params: Union[AgentInputParams, OneShotInputParams, JoinUrlInputParams],
|
||||
one_shot_selected_tools: Optional[ToolsSchema] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Ultravox Realtime LLM service.
|
||||
|
||||
Args:
|
||||
api_key: Ultravox API key for authentication.
|
||||
params: Configuration parameters for the model.
|
||||
one_shot_selected_tools: ToolsSchema for tools to use with this call.
|
||||
May only be set with OneShotInputParams.
|
||||
**kwargs: Additional arguments passed to parent LLMService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._params = params
|
||||
if one_shot_selected_tools:
|
||||
if not isinstance(self._params, OneShotInputParams):
|
||||
logger.warning(
|
||||
"one_shot_selected_tools may only be set when using OneShotInputParams; ignoring."
|
||||
)
|
||||
else:
|
||||
self._selected_tools = one_shot_selected_tools
|
||||
|
||||
self._socket: Optional[websocket_client.ClientConnection] = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._disconnecting = False
|
||||
self._bot_responding: Literal[None, "text", "voice"] = None
|
||||
self._last_user_id: Optional[str] = None
|
||||
|
||||
self._sample_rate = 48000
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
try:
|
||||
match self._params:
|
||||
case JoinUrlInputParams():
|
||||
join_url = self._params.join_url
|
||||
case AgentInputParams():
|
||||
join_url = await self._start_agent_call(self._params)
|
||||
case OneShotInputParams():
|
||||
join_url = await self._start_one_shot_call(self._params)
|
||||
|
||||
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
|
||||
self._socket = await websocket_client.connect(join_url)
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
except Exception as e:
|
||||
await self.push_error("Failed to connect to Ultravox", e, fatal=True)
|
||||
|
||||
async def _start_agent_call(self, params: AgentInputParams) -> str:
|
||||
request_body = {
|
||||
"templateContext": params.template_context,
|
||||
"metadata": params.metadata,
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
}
|
||||
if params.max_duration:
|
||||
request_body["maxDuration"] = f"{params.max_duration.total_seconds():3f}s"
|
||||
request_body = request_body | params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://api.ultravox.ai/api/agents/{params.agent_id}/calls",
|
||||
headers={"X-Api-Key": params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
return (await response.json())["joinUrl"]
|
||||
|
||||
async def _start_one_shot_call(self, params: OneShotInputParams) -> str:
|
||||
request_body = {
|
||||
"systemPrompt": params.system_prompt,
|
||||
"temperature": params.temperature,
|
||||
"model": params.model,
|
||||
"voice": str(params.voice) if params.voice else None,
|
||||
"metadata": params.metadata,
|
||||
"maxDuration": f"{params.max_duration.total_seconds():3f}s",
|
||||
"selectedTools": self._to_selected_tools(self._selected_tools)
|
||||
if self._selected_tools
|
||||
else [],
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
} | params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.ultravox.ai/api/calls",
|
||||
headers={"X-Api-Key": params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
return (await response.json())["joinUrl"]
|
||||
|
||||
def _to_selected_tools(self, tool: ToolsSchema) -> List[Dict[str, Any]]:
|
||||
result: List[Dict[str, Any]] = []
|
||||
for standard_tool in tool.standard_tools:
|
||||
result.append(
|
||||
{
|
||||
"temporaryTool": {
|
||||
"modelToolName": standard_tool.name,
|
||||
"description": standard_tool.description,
|
||||
"dynamicParameters": [
|
||||
{
|
||||
"name": k,
|
||||
"location": "PARAMETER_LOCATION_BODY",
|
||||
"schema": v,
|
||||
"required": k in standard_tool.required,
|
||||
}
|
||||
for k, v in standard_tool.properties.items()
|
||||
],
|
||||
"client": {},
|
||||
}
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close connections.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the service and close connections.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _disconnect(self):
|
||||
self._disconnecting = True
|
||||
if self._socket:
|
||||
await self._socket.close()
|
||||
self._socket = None
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
|
||||
#
|
||||
# frame processing
|
||||
# StartFrame, StopFrame, CancelFrame implemented in base class
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames for the Ultravox Realtime service.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
if "output_medium" in frame.settings:
|
||||
await self._update_output_medium(frame.settings.get("output_medium"))
|
||||
elif isinstance(frame, InputTextRawFrame):
|
||||
await self._send_user_text(frame.text)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._send_user_audio(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
# Ultravox handles all context server-side, so the only context we may
|
||||
# need to handle here is new function call results.
|
||||
for message in reversed(context.messages):
|
||||
if message.get("role") != "tool":
|
||||
break
|
||||
content = message.get("content")
|
||||
socket_message = {
|
||||
"type": "client_tool_result",
|
||||
"invocationId": message.get("tool_call_id"),
|
||||
"result": content
|
||||
if isinstance(content, str)
|
||||
else "".join(t.get("text") for t in content),
|
||||
}
|
||||
await self._send(socket_message)
|
||||
|
||||
async def _send_user_audio(self, frame: InputAudioRawFrame):
|
||||
"""Send user audio frame to Ultravox Realtime."""
|
||||
if not self._socket:
|
||||
return
|
||||
self._last_user_id = frame.user_id if isinstance(frame, UserAudioRawFrame) else None
|
||||
audio = frame.audio
|
||||
if frame.sample_rate != self._sample_rate:
|
||||
audio = await self._resampler.resample(audio, frame.sample_rate, self._sample_rate)
|
||||
await self._send(audio)
|
||||
|
||||
async def _send_user_text(self, text: str):
|
||||
"""Send user text via Ultravox Realtime.
|
||||
|
||||
Args:
|
||||
text: The text to send as user input.
|
||||
"""
|
||||
if not self._socket:
|
||||
return
|
||||
await self._send({"type": "user_text_message", "text": text})
|
||||
|
||||
async def _update_output_medium(self, output_medium: str):
|
||||
output_medium = output_medium.lower()
|
||||
if output_medium == "audio":
|
||||
output_medium = "voice"
|
||||
if output_medium.lower() not in {"voice", "text"}:
|
||||
logger.warning(f"Unsupported Ultravox output medium: {output_medium}")
|
||||
return
|
||||
await self._send({"type": "set_output_medium", "medium": output_medium})
|
||||
|
||||
async def _send(self, content: Union[bytes, Dict[str, Any]]):
|
||||
"""Send content via the WebSocket connection.
|
||||
|
||||
Args:
|
||||
content: The content to send, either as bytes or a JSON-serializable dict.
|
||||
"""
|
||||
if self._disconnecting or not self._socket:
|
||||
return
|
||||
|
||||
try:
|
||||
if isinstance(content, bytes):
|
||||
await self._socket.send(content)
|
||||
else:
|
||||
await self._socket.send(json.dumps(content))
|
||||
except Exception as e:
|
||||
if self._disconnecting or not self._socket:
|
||||
return
|
||||
await self.push_error("Ultravox websocket send error", e, fatal=True)
|
||||
|
||||
#
|
||||
# response handling
|
||||
#
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive messages from the Ultravox Realtime WebSocket."""
|
||||
if not self._socket:
|
||||
return
|
||||
async for message in self._socket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
await self._handle_audio(message)
|
||||
continue
|
||||
|
||||
data = json.loads(message)
|
||||
match data.get("type"):
|
||||
case "state":
|
||||
if self._bot_responding and data.get("state") != "speaking":
|
||||
await self._handle_response_end()
|
||||
case "client_tool_invocation":
|
||||
await self._handle_tool_invocation(
|
||||
data.get("toolName"), data.get("invocationId"), data.get("parameters")
|
||||
)
|
||||
case "transcript":
|
||||
match data.get("role"):
|
||||
case "user":
|
||||
if not data.get("final"):
|
||||
logger.warning(
|
||||
"Unexpected non-final user transcript from Ultravox Realtime; ignoring."
|
||||
)
|
||||
else:
|
||||
await self._handle_user_transcript(data.get("text"))
|
||||
case "agent":
|
||||
await self._handle_agent_transcript(
|
||||
data.get("medium"),
|
||||
data.get("text"),
|
||||
data.get("delta"),
|
||||
data.get("final", False),
|
||||
)
|
||||
case _:
|
||||
logger.debug(
|
||||
f"Received transcript with unknown role from Ultravox Realtime: {data}"
|
||||
)
|
||||
case _:
|
||||
logger.debug(f"Received unhandled Ultravox message: {data}")
|
||||
except Exception as e:
|
||||
if self._disconnecting or not self._socket:
|
||||
return
|
||||
await self.push_error("Ultravox websocket receive error", e, fatal=True)
|
||||
|
||||
async def _handle_audio(self, audio: bytes):
|
||||
"""Handle incoming audio bytes from Ultravox Realtime."""
|
||||
if not audio:
|
||||
return
|
||||
if not self._bot_responding:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
self._bot_responding = "voice"
|
||||
await self.push_frame(TTSAudioRawFrame(audio, self._sample_rate, 1))
|
||||
|
||||
async def _handle_response_end(self):
|
||||
if self._bot_responding == "voice":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._bot_responding = None
|
||||
|
||||
async def _handle_tool_invocation(
|
||||
self, tool_name: str, invocation_id: str, parameters: Dict[str, Any]
|
||||
):
|
||||
await self.run_function_calls(
|
||||
[
|
||||
FunctionCallFromLLM(
|
||||
function_name=tool_name,
|
||||
tool_call_id=invocation_id,
|
||||
arguments=parameters,
|
||||
context=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def _handle_user_transcript(self, text: str):
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
user_id=self._last_user_id or "",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=text,
|
||||
text=text,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
async def _handle_agent_transcript(
|
||||
self, medium: str, text: Optional[str], delta: Optional[str], final: bool
|
||||
):
|
||||
if text or delta:
|
||||
frame = LLMTextFrame(text=text or delta)
|
||||
frame.skip_tts = medium == "voice"
|
||||
await self.push_frame(frame)
|
||||
if medium == "text":
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(TTSTextFrame(text=text, aggregated_by=AggregationType.WORD))
|
||||
self._bot_responding = "text"
|
||||
elif final:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._bot_responding = None
|
||||
elif delta:
|
||||
await self.push_frame(TTSTextFrame(text=delta, aggregated_by=AggregationType.WORD))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create an instance of LLMContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators can be provided.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do::
|
||||
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Args:
|
||||
context: The LLM context to use.
|
||||
user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams().
|
||||
assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams().
|
||||
|
||||
Returns:
|
||||
A pair of user and assistant context aggregators.
|
||||
"""
|
||||
context = LLMContext.from_openai_context(context)
|
||||
assistant_params.expect_stripped_words = False
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
@@ -1,448 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""This module implements Ultravox speech-to-text with a locally-loaded model."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import login
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import AsyncLLMEngine, SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Ultravox, you need to `pip install pipecat-ai[ultravox]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AudioBuffer:
    """Accumulates the audio frames of one speech segment.

    Holds the collected frames together with the time the recording started
    and a flag telling whether the buffer is currently being processed.
    """

    def __init__(self):
        """Create an empty buffer that is neither recording nor processing."""
        # Not processing, no recording started, nothing collected yet.
        self.is_processing: bool = False
        self.started_at: Optional[float] = None
        self.frames: List[AudioRawFrame] = []
|
||||
|
||||
|
||||
class UltravoxModel:
    """Wrapper around the Ultravox multimodal model.

    Loads the model into a vLLM async engine together with its tokenizer and
    exposes streaming text generation from audio input.
    """

    def __init__(self, model_name: str = "fixie-ai/ultravox-v0_5-llama-3_1-8b"):
        """Load the engine and tokenizer for the given model.

        Args:
            model_name: The name or path of the Ultravox model to load.
                Defaults to "fixie-ai/ultravox-v0_5-llama-3_1-8b".
        """
        self.model_name = model_name
        self._initialize_engine()
        self._initialize_tokenizer()
        self.stop_token_ids = None

    def _initialize_engine(self):
        """Create the vLLM async engine used for inference."""
        self.engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=self.model_name,
                gpu_memory_utilization=0.9,
                max_model_len=8192,
                trust_remote_code=True,
            )
        )

    def _initialize_tokenizer(self):
        """Load the tokenizer that belongs to the model."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def format_prompt(self, messages: list):
        """Apply the tokenizer's chat template to a list of chat messages.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.

        Returns:
            The value returned by the tokenizer's chat template, requested
            untokenized and with the generation prompt appended.
        """
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    async def generate(
        self,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 100,
        audio: np.ndarray = None,
    ):
        """Stream generated text for the given messages and audio.

        Args:
            messages: List of message dictionaries for conversation context.
            temperature: Sampling temperature for generation randomness.
            max_tokens: Maximum number of tokens to generate.
            audio: Audio data as numpy array in float32 format.

        Yields:
            str: JSON chunks of the generated response in OpenAI format.
        """
        sampling = SamplingParams(
            temperature=temperature, max_tokens=max_tokens, stop_token_ids=self.stop_token_ids
        )
        request = {
            "prompt": self.format_prompt(messages),
            "multi_modal_data": {"audio": audio},
        }
        # The current timestamp serves as the request id.
        stream = self.engine.generate(request, sampling, str(time.time()))

        emitted_text = ""
        role_pending = True

        async for output in stream:
            completion = output.outputs[0]
            full_text = completion.text
            # The engine reports the text generated so far; keep only what is new.
            fresh_text = full_text[len(emitted_text) :]
            emitted_text = full_text

            delta = {}
            # Include the role in the first chunk only.
            if role_pending:
                delta["role"] = "assistant"
                role_pending = False
            if fresh_text:
                delta["content"] = fresh_text

            # Report a finish reason only when one is really provided.
            reason = completion.finish_reason or None
            reported_reason = reason if reason and reason != "none" else None

            # OpenAI-compatible chunk
            yield json.dumps(
                {
                    "id": str(int(time.time() * 1000)),
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": self.model_name,
                    "choices": [
                        {
                            "index": 0,
                            "delta": delta,
                            "finish_reason": reported_reason,
                        }
                    ],
                }
            )
|
||||
|
||||
|
||||
class UltravoxSTTService(AIService):
    """Service to transcribe audio using the Ultravox multimodal model.

    This service collects audio frames during speech and processes them with
    Ultravox to generate text transcriptions. It handles real-time audio
    buffering, model warm-up, and streaming text generation.
    """

    def __init__(
        self,
        *,
        model_name: str = "fixie-ai/ultravox-v0_5-llama-3_1-8b",
        hf_token: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 100,
        **kwargs,
    ):
        """Initialize the UltravoxSTTService.

        Args:
            model_name: The Ultravox model to use. Defaults to
                "fixie-ai/ultravox-v0_5-llama-3_1-8b".
            hf_token: Hugging Face token for model access. If None, will try
                to use HF_TOKEN environment variable.
            temperature: Sampling temperature for generation. Defaults to 0.7.
            max_tokens: Maximum tokens to generate. Defaults to 100.
            **kwargs: Additional arguments passed to AIService.
        """
        super().__init__(**kwargs)

        # Authenticate with Hugging Face: the explicit token wins, the
        # HF_TOKEN environment variable is the fallback.
        token = hf_token or os.environ.get("HF_TOKEN")
        if token:
            login(token=token)
        else:
            logger.warning("No Hugging Face token provided. Model may not load correctly.")

        # Initialize model
        self._model = UltravoxModel(model_name=model_name)

        # Initialize service state
        self._buffer = AudioBuffer()
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._connection_active = False
        self._warm_up_duration_sec = 1

        logger.info(f"Initialized UltravoxSTTService with model: {model_name}")

    async def _warm_up_model(self):
        """Warm up the model with silent audio to improve first inference performance.

        Generates a short segment of silent audio and runs it through the
        model. Failures are reported through ``push_error`` and not raised.
        """
        logger.info("Warming up Ultravox model with silent audio...")

        # Generate silent audio at 16kHz sample rate
        sample_rate = 16000
        silent_audio = self._generate_silent_audio(sample_rate, self._warm_up_duration_sec)

        try:
            messages = [{"role": "user", "content": "<|audio|>\n"}]
            warmup_generator = self._model.generate(
                messages=messages,
                temperature=self._temperature,
                max_tokens=self._max_tokens,
                audio=silent_audio,
            )

            # Consume the generator to actually run the inference
            async for _ in warmup_generator:
                pass

            logger.info("Model warm-up completed successfully")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)

    def _generate_silent_audio(self, sample_rate=16000, duration_sec=1.0):
        """Generate silent audio as a numpy array.

        Args:
            sample_rate: Sample rate in Hz
            duration_sec: Duration of silence in seconds

        Returns:
            np.ndarray: Float32 array of zeros representing silent audio
        """
        num_samples = int(sample_rate * duration_sec)
        silent_audio = np.zeros(num_samples, dtype=np.float32)

        logger.info(f"Generated {duration_sec}s of silent audio ({num_samples} samples)")
        return silent_audio

    def can_generate_metrics(self) -> bool:
        """Indicates whether this service can generate metrics.

        Returns:
            bool: True, as this service supports metric generation.
        """
        return True

    async def start(self, frame: StartFrame):
        """Handle service start.

        Starts the service, marks it as active, and performs model warm-up
        to ensure optimal performance for the first inference.

        Args:
            frame: StartFrame that triggered this method.
        """
        await super().start(frame)
        self._connection_active = True

        await self._warm_up_model()

        logger.info("UltravoxSTTService started")

    async def stop(self, frame: EndFrame):
        """Handle service stop.

        Stops the service and marks it as inactive.

        Args:
            frame: EndFrame that triggered this method.
        """
        await super().stop(frame)
        self._connection_active = False
        logger.info("UltravoxSTTService stopped")

    async def cancel(self, frame: CancelFrame):
        """Handle service cancellation.

        Cancels the service, clears any buffered audio, and marks it as inactive.

        Args:
            frame: CancelFrame that triggered this method.
        """
        await super().cancel(frame)
        self._connection_active = False
        self._buffer = AudioBuffer()
        logger.info("UltravoxSTTService cancelled")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames.

        This method collects audio frames during speech and processes them
        when speech ends to generate text transcriptions.

        Args:
            frame: The frame to process.
            direction: Direction of the frame (input/output).
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, UserStartedSpeakingFrame):
            logger.info("Speech started")
            self._buffer = AudioBuffer()
            self._buffer.started_at = time.time()

        elif isinstance(frame, AudioRawFrame) and self._buffer.started_at is not None:
            # Audio is only collected while a speech segment is open.
            self._buffer.frames.append(frame)

        elif isinstance(frame, UserStoppedSpeakingFrame):
            if self._buffer.frames and not self._buffer.is_processing:
                logger.info("Speech ended, processing buffer...")
                await self.process_generator(self._process_audio_buffer())
                # NOTE(review): returning here means the UserStoppedSpeakingFrame
                # is not pushed on once the buffer has been processed -- kept as
                # in the original; confirm this is intended.
                return

        await self.push_frame(frame, direction)

    async def _process_audio_buffer(self) -> AsyncGenerator[Frame, None]:
        """Process collected audio frames with Ultravox.

        This method concatenates audio frames, processes them with the model,
        and yields the resulting frames. The buffer is reset afterwards,
        whether processing succeeded or not.

        Yields:
            Frame: LLMFullResponseStartFrame, one LLMTextFrame per generated
            text chunk and LLMFullResponseEndFrame on success; ErrorFrame when
            there is no usable audio or processing fails.
        """
        try:
            self._buffer.is_processing = True

            # Check if we have valid frames before processing
            if not self._buffer.frames:
                logger.warning("No audio frames to process")
                yield ErrorFrame("No audio frames to process")
                return

            # Collect the audio of every frame as an int16 array
            audio_arrays = []
            for f in self._buffer.frames:
                # Compare against None instead of testing truthiness: the truth
                # value of a multi-element numpy array raises ValueError.
                audio = getattr(f, "audio", None)
                if audio is None:
                    continue
                # Handle bytes data - these are int16 PCM samples
                if isinstance(audio, bytes):
                    try:
                        arr = np.frombuffer(audio, dtype=np.int16)
                        if arr.size > 0:  # Skip empty payloads
                            audio_arrays.append(arr)
                    except Exception as e:
                        yield ErrorFrame(error=f"Unknown error occurred: {e}")
                # Handle numpy array data
                elif isinstance(audio, np.ndarray):
                    if audio.size > 0:  # Skip empty arrays
                        # Ensure it's int16 data
                        if audio.dtype != np.int16:
                            logger.info(f"Converting array from {audio.dtype} to int16")
                            audio_arrays.append(audio.astype(np.int16))
                        else:
                            audio_arrays.append(audio)

            # Only proceed if we have valid audio arrays
            if not audio_arrays:
                logger.warning("No valid audio data found in frames")
                yield ErrorFrame("No valid audio data found in frames")
                return

            # Concatenate audio frames - all are int16 now
            audio_int16 = np.concatenate(audio_arrays)

            # Convert int16 to float32 and normalize for model input
            audio_float32 = audio_int16.astype(np.float32) / 32768.0

            # Generate text using the model
            if self._model:
                try:
                    logger.info("Generating text from audio using model...")

                    # Start metrics tracking
                    await self.start_ttfb_metrics()
                    await self.start_processing_metrics()

                    yield LLMFullResponseStartFrame()

                    async for response in self._model.generate(
                        messages=[{"role": "user", "content": "<|audio|>\n"}],
                        temperature=self._temperature,
                        max_tokens=self._max_tokens,
                        audio=audio_float32,
                    ):
                        # Stop TTFB metrics after first response
                        await self.stop_ttfb_metrics()

                        chunk = json.loads(response)
                        if "choices" in chunk and len(chunk["choices"]) > 0:
                            delta = chunk["choices"][0]["delta"]
                            if "content" in delta:
                                new_text = delta["content"]
                                if new_text:
                                    yield LLMTextFrame(text=new_text)

                    # Stop processing metrics after completion
                    await self.stop_processing_metrics()

                    yield LLMFullResponseEndFrame()

                except Exception as e:
                    yield ErrorFrame(error=f"Unknown error occurred: {e}")
            else:
                yield ErrorFrame("No model available for text generation")

        except Exception as e:
            yield ErrorFrame(f"Error processing audio: {str(e)}")
        finally:
            # Always leave the buffer empty and idle for the next speech segment.
            self._buffer.is_processing = False
            self._buffer.frames = []
            self._buffer.started_at = None
|
||||
@@ -12,7 +12,7 @@ from typing import Awaitable, Callable, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
from websockets.protocol import State
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame
|
||||
@@ -137,6 +137,10 @@ class WebsocketService(ABC):
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
break
|
||||
except ConnectionClosedError as e:
|
||||
# Error closure, don't retry
|
||||
logger.warning(f"{self} connection closed, but with an error: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.audio.dtmf.utils import load_dtmf_audio
|
||||
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
||||
from pipecat.audio.utils import create_stream_resampler, is_silence
|
||||
from pipecat.frames.frames import (
|
||||
AssistantImageRawFrame,
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -335,6 +336,10 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await sender.handle_audio_frame(frame)
|
||||
elif isinstance(frame, (OutputImageRawFrame, SpriteFrame)):
|
||||
await sender.handle_image_frame(frame)
|
||||
if isinstance(frame, AssistantImageRawFrame):
|
||||
# This will push it further, to be handled by the assistant
|
||||
# aggregator, say
|
||||
await sender.handle_sync_frame(frame)
|
||||
elif isinstance(frame, MixerControlFrame):
|
||||
await sender.handle_mixer_control_frame(frame)
|
||||
elif frame.pts:
|
||||
@@ -753,7 +758,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._handle_frame(frame)
|
||||
|
||||
# If we are not able to write to the transport we shouldn't
|
||||
# pushb downstream.
|
||||
# push downstream.
|
||||
push_downstream = True
|
||||
|
||||
# Try to send audio to the transport.
|
||||
|
||||
@@ -16,7 +16,7 @@ The module consists of three main components:
|
||||
- HeyGenTransport: Main transport implementation that coordinates input/output transports
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -36,8 +36,9 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.heygen.api import NewSessionRequest
|
||||
from pipecat.services.heygen.client import HeyGenCallbacks, HeyGenClient
|
||||
from pipecat.services.heygen.api_interactive_avatar import NewSessionRequest
|
||||
from pipecat.services.heygen.api_liveavatar import LiveAvatarNewSessionRequest
|
||||
from pipecat.services.heygen.client import HeyGenCallbacks, HeyGenClient, ServiceType
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
@@ -297,10 +298,8 @@ class HeyGenTransport(BaseTransport):
|
||||
params: HeyGenParams = HeyGenParams(),
|
||||
input_name: Optional[str] = None,
|
||||
output_name: Optional[str] = None,
|
||||
session_request: NewSessionRequest = NewSessionRequest(
|
||||
avatar_id="Shawn_Therapist_public",
|
||||
version="v2",
|
||||
),
|
||||
session_request: Optional[Union[LiveAvatarNewSessionRequest, NewSessionRequest]] = None,
|
||||
service_type: Optional[ServiceType] = None,
|
||||
):
|
||||
"""Initialize the HeyGen transport.
|
||||
|
||||
@@ -313,7 +312,8 @@ class HeyGenTransport(BaseTransport):
|
||||
params: HeyGen-specific configuration parameters (default: HeyGenParams())
|
||||
input_name: Optional custom name for the input transport
|
||||
output_name: Optional custom name for the output transport
|
||||
session_request: Configuration for the HeyGen session (default: uses Shawn_Therapist_public avatar)
|
||||
session_request: Configuration for the HeyGen session
|
||||
service_type: Service type for the avatar session
|
||||
|
||||
Note:
|
||||
The transport will automatically join the same virtual room as the HeyGen Avatar
|
||||
@@ -326,6 +326,7 @@ class HeyGenTransport(BaseTransport):
|
||||
session=session,
|
||||
params=params,
|
||||
session_request=session_request,
|
||||
service_type=service_type,
|
||||
callbacks=HeyGenCallbacks(
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
|
||||
@@ -160,7 +160,7 @@ class SmallWebRTCRequestHandler:
|
||||
self,
|
||||
request: SmallWebRTCRequest,
|
||||
webrtc_connection_callback: Callable[[Any], Awaitable[None]],
|
||||
) -> None:
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""Handle a SmallWebRTC request and resolve the pending answer.
|
||||
|
||||
This method will:
|
||||
@@ -176,6 +176,10 @@ class SmallWebRTCRequestHandler:
|
||||
webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
|
||||
asynchronous callback function that is invoked with the WebRTC connection.
|
||||
|
||||
Returns:
|
||||
Dictionary containing SDP answer, type, and peer connection ID,
|
||||
or None if no answer is available.
|
||||
|
||||
Raises:
|
||||
HTTPException: If connection mode constraints are violated
|
||||
Exception: Any exception raised during request handling or callback execution
|
||||
|
||||
@@ -12,6 +12,7 @@ comprehensive monitoring and cleanup capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Dict, Optional, Sequence
|
||||
@@ -162,7 +163,9 @@ class TaskManager(BaseTaskManager):
|
||||
# Re-raise the exception to ensure the task is cancelled.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{name}: unexpected exception: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.error(f"{name} unexpected exception ({last.filename}:{last.lineno}): {e}")
|
||||
|
||||
if not self._params:
|
||||
raise Exception("TaskManager is not setup: unable to get event loop")
|
||||
@@ -197,9 +200,17 @@ class TaskManager(BaseTaskManager):
|
||||
# Here are sure the task is cancelled properly.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{name}: unexpected exception while cancelling task: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.error(
|
||||
f"{name} unexpected exception while cancelling task ({last.filename}:{last.lineno}): {e}"
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.critical(f"{name}: fatal base exception while cancelling task: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.critical(
|
||||
f"{name} fatal base exception while cancelling task ({last.filename}:{last.lineno}): {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def current_tasks(self) -> Sequence[asyncio.Task]:
|
||||
|
||||
@@ -13,6 +13,7 @@ and async cleanup for all Pipecat components.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -187,7 +188,11 @@ class BaseObject(ABC):
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.error(
|
||||
f"Uncaught exception in event handler '{event_name}' ({last.filename}:{last.lineno}): {e}"
|
||||
)
|
||||
|
||||
def _event_task_finished(self, task: asyncio.Task):
|
||||
"""Clean up completed event handler tasks.
|
||||
|
||||
@@ -203,7 +203,7 @@ def parse_start_end_tags(
|
||||
class TextPartForConcatenation:
|
||||
"""Class representing a part of text for concatenation with concatenate_aggregated_text.
|
||||
|
||||
Attributes:
|
||||
Parameters:
|
||||
text: The text content.
|
||||
includes_inter_part_spaces: Whether any necessary inter-frame
|
||||
(leading/trailing) spaces are already included in the text.
|
||||
|
||||
@@ -26,15 +26,15 @@ class MatchAction(Enum):
|
||||
|
||||
Parameters:
|
||||
REMOVE: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
KEEP: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
AGGREGATE: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
@@ -133,17 +133,15 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
|
||||
Args:
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
action: What to do when a complete pattern is matched.
|
||||
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as normal text. This allows you to register handlers for the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
@@ -40,7 +40,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
return Aggregation(text=self._text.strip(" "), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text and yield completed sentences.
|
||||
@@ -97,7 +97,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
# NLTK confirmed a sentence - return it
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result, type=AggregationType.SENTENCE)
|
||||
return Aggregation(text=result.strip(" "), type=AggregationType.SENTENCE)
|
||||
# No sentence found - keep accumulating
|
||||
return None
|
||||
# Still whitespace, keep waiting
|
||||
@@ -123,7 +123,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
# Return whatever we have in the buffer
|
||||
result = self._text
|
||||
await self.reset()
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return Aggregation(text=result.strip(" "), type=AggregationType.SENTENCE)
|
||||
return None
|
||||
|
||||
async def handle_interruption(self):
|
||||
|
||||
@@ -483,7 +483,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
# Add all available attributes to the span
|
||||
attribute_kwargs = {
|
||||
"service_name": service_class_name,
|
||||
"model": getattr(self, "model_name", "unknown"),
|
||||
"model": getattr(
|
||||
self, getattr(self, "_full_model_name", "model_name"), "unknown"
|
||||
),
|
||||
"stream": True, # Most LLM services use streaming
|
||||
"parameters": params,
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user