Compare commits

...

94 Commits

Author SHA1 Message Date
Mark Backman
4cefe1357c Merge pull request #3201 from pipecat-ai/changelog-0.0.97
Release 0.0.97 - Changelog Update
2025-12-05 18:49:15 -05:00
markbackman
4df0a9bf73 Update changelog for version 0.0.97 2025-12-05 18:47:21 -05:00
Mark Backman
9ef139d020 Merge pull request #3200 from pipecat-ai/mb/improve-changelog-template
Fix newlines between sections in changlelog template
2025-12-05 18:42:52 -05:00
Mark Backman
9103d4ae05 Fix newlines between sections in changlelog template 2025-12-05 18:40:49 -05:00
Aleix Conchillo Flaqué
bd63b6cefa Merge pull request #3198 from pipecat-ai/aleix/examples-14i-new-model
examples(foundational): update 14i-fireworks with new serverless model
2025-12-05 15:33:12 -08:00
Aleix Conchillo Flaqué
4d03270bc3 examples(foundational): update 14i-fireworks with new serverless model 2025-12-05 15:31:29 -08:00
Mark Backman
6aee72c5b4 Merge pull request #3196 from pipecat-ai/mb/docs-cleanup-prep-0.0.97
Docs cleanup before 0.0.97 release
2025-12-05 15:16:36 -05:00
Mark Backman
8d62cfb1b6 Merge pull request #3195 from ivaaan/add-hume-header
Add tracking headers to Hume service
2025-12-05 14:50:18 -05:00
ivaaan
41214236ab add changelog 2025-12-05 20:47:04 +01:00
Mark Backman
b25963a63b Docs cleanup before 0.0.97 release 2025-12-05 14:19:26 -05:00
ivaaan
8c6ef21d84 add stop, cancel 2025-12-05 20:13:58 +01:00
ivaaan
0ffaa09c95 add tracking headers to Hume service 2025-12-05 19:00:47 +01:00
Aleix Conchillo Flaqué
f6e31b7e89 Merge pull request #3185 from pipecat-ai/fix/websocket-service-cancelled-error-handling
fix(websocket): handle CancelledError to prevent reconnection on shutdown
2025-12-05 09:25:49 -08:00
Aleix Conchillo Flaqué
48422dd442 WebsocketService: avoid reconnection on shutdown 2025-12-05 09:03:04 -08:00
Vanessa Pyne
fed6a8b669 Merge pull request #3187 from pipecat-ai/vp-mcp-filter-followup
add mcp filter example and changelog
2025-12-05 10:58:19 -06:00
vipyne
82e0253a62 add mcp filter example and changelog 2025-12-05 10:56:59 -06:00
Vanessa Pyne
a7f26dca60 Merge pull request #3152 from RuiDaniel/mcp_client_filters
Add filters to MCP Client
2025-12-05 10:50:27 -06:00
Vanessa Pyne
459ef27f3f Merge pull request #3079 from pipecat-ai/vp-add-exact-model-version-function
set full model name for base openai models
2025-12-05 10:48:53 -06:00
Mark Backman
464cfa5ccb Merge pull request #3188 from pipecat-ai/mb/improve-changelog-process
Auto-generate changelog from fragments
2025-12-05 11:42:25 -05:00
Mark Backman
9289881a80 Remove 3120.added.md 2025-12-05 11:35:50 -05:00
Mark Backman
34033cd454 Add new changelog entries 2025-12-05 11:35:50 -05:00
Mark Backman
47c21c9579 Delete README.md in changelog 2025-12-05 11:35:50 -05:00
Mark Backman
3b0bcf0b66 Validate fragment types match the expected types 2025-12-05 11:35:50 -05:00
Mark Backman
c4a8308027 Fail when no changelog fragments are available 2025-12-05 11:35:50 -05:00
Mark Backman
e9f76dcaf2 Set the date automatically when the workflow runs, leaving an optional override 2025-12-05 11:35:50 -05:00
Mark Backman
21b2229b2b Auto-generate changelog from fragments 2025-12-05 11:35:49 -05:00
Aleix Conchillo Flaqué
11aa9c9e68 update CHANGELOG, remove wait_for_all 2025-12-05 08:34:07 -08:00
Aleix Conchillo Flaqué
9f4680e9bd Merge pull request #3190 from pipecat-ai/aleix/no-need-wait-for-all
LLMService: let's not introduce wait_for_all for now
2025-12-05 08:31:44 -08:00
Aleix Conchillo Flaqué
04443a3820 LLMService: let's not introduce wait_for_all for now 2025-12-05 08:26:04 -08:00
Mark Backman
1571cc58ac Merge pull request #3192 from pipecat-ai/mb/cartesia-stt-timestamp
Add full transcript result for CartesiaSTTService
2025-12-05 10:37:06 -05:00
Mark Backman
dea80cf946 Add full transcript result for CartesiaSTTService 2025-12-05 10:25:46 -05:00
Mark Backman
91dec044c4 Merge pull request #3171 from LaurentMazare/gradium
Gradium integration.
2025-12-05 09:43:44 -05:00
laurent
8cf4267d87 Switch to a debug. 2025-12-05 15:37:17 +01:00
Mark Backman
0ee7cab6c6 Merge pull request #3184 from ashotbagh/feat/asyncai-multilingual-addons
Added new languages support for AsyncAI
2025-12-05 08:42:09 -05:00
Ashot
74c2039bfb Updated changelog. 2025-12-05 16:54:38 +04:00
Ashot
66088837cd Fixed defualt language issue in async tts 2025-12-05 16:51:05 +04:00
laurent
07ebf8534a Add the example. 2025-12-05 10:51:22 +01:00
laurent
fce4cfba15 Changelog update. 2025-12-05 10:46:01 +01:00
laurent
af52833ca0 Update the readme and env.example. 2025-12-05 10:44:30 +01:00
laurent
9fdf756375 Fix. 2025-12-05 10:38:35 +01:00
laurent
283bbb385c And remove the request-id. 2025-12-05 10:35:19 +01:00
laurent
8c6b2edb25 Various code review tweaks. 2025-12-05 10:33:48 +01:00
Laurent Mazare
6ab30f9b87 Apply suggestions from code review
Co-authored-by: Mark Backman <m.backman@gmail.com>
2025-12-05 10:25:47 +01:00
Aleix Conchillo Flaqué
3d93285bdf Merge pull request #3176 from pipecat-ai/aleix/exception-filename-line-number
log file name and line number when exception occurs
2025-12-04 11:08:32 -08:00
Aleix Conchillo Flaqué
7261cd28f2 log file name and line number when exception occurs 2025-12-04 11:06:45 -08:00
vipyne
33eeb8ce44 Use _full_model_name in llm trace if available 2025-12-04 11:54:45 -06:00
vipyne
ebda94ca98 set full model name for base openai models 2025-12-04 11:54:45 -06:00
Mark Backman
40b17cff8f Merge pull request #3186 from pipecat-ai/mb/11labs-fix-metrics-tracking
fix: ElevenLabsTTSService character usage metrics
2025-12-04 12:36:39 -05:00
marcus-daily
7ba0ebba11 Smart Turn analyzer now uses the full context of the turn rather than just the audio since VAD last triggered (fixes #3094) 2025-12-04 16:40:08 +00:00
Mark Backman
b39087027c fix: ElevenLabsTTSService character usage metrics 2025-12-04 09:41:18 -05:00
Ashot
e65974c870 Added new languages support for AsyncAI 2025-12-04 16:15:28 +04:00
marcus-daily
b1e5d68d97 Updating changelog 2025-12-04 11:32:16 +00:00
marcus-daily
39bca074d7 Smart Turn v3.1 2025-12-04 11:32:16 +00:00
Aleix Conchillo Flaqué
b5e79f9dc5 Merge pull request #3181 from pipecat-ai/aleix/sync-to-utils-sync
move pipecat.sync to pipecat.utils.sync
2025-12-03 19:41:18 -08:00
Aleix Conchillo Flaqué
613b96819f Merge pull request #3180 from pipecat-ai/aleix/deepgram-tts-service-fix
DeepgramTTSService: fix websocket header logging
2025-12-03 19:40:43 -08:00
Mark Backman
57c24670ea Merge pull request #3132 from pipecat-ai/mb/normalize-llm-text-frame-output
Add split_text_by_spaces string util, normalize aggregator input
2025-12-03 22:05:14 -05:00
Mark Backman
d79dd94019 Make aggregate return an AsyncIterator, other clean up 2025-12-03 22:00:34 -05:00
Mark Backman
fa8e7458e1 Clean up 2025-12-03 22:00:04 -05:00
Mark Backman
4d66191963 fix: PatternPairAggregator to process patterns only once 2025-12-03 22:00:04 -05:00
Mark Backman
7e9d67002e SkipTagsAggregator and PatternPairAggregator now subclass SimpleTextAggregator 2025-12-03 22:00:04 -05:00
Mark Backman
ffbb6e5937 Update SimpleTextAggregator to handle character by character input, use a buffer to handle ambiguous EOS scenarios, and add a flush method to all aggregators 2025-12-03 22:00:02 -05:00
Mark Backman
535b85cf90 Add split_text_by_spaces string util 2025-12-03 21:55:30 -05:00
Aleix Conchillo Flaqué
8dc9872ed5 deprecate pipecat.sync package 2025-12-03 18:44:41 -08:00
Aleix Conchillo Flaqué
f37a53cc25 utils(sync): move sync to utils.sync 2025-12-03 18:20:12 -08:00
Aleix Conchillo Flaqué
9cce28c64c DeepgramTTSService: use websocket response headers for logging 2025-12-03 18:16:25 -08:00
Aleix Conchillo Flaqué
3ca94363ec Merge pull request #3168 from pipecat-ai/aleix/dont-override-skip-tts
LLMTextFrame: don't override skip_tts
2025-12-03 18:15:50 -08:00
Rpcd
9dd882ecf8 Update src/pipecat/services/mcp_service.py
Co-authored-by: Vanessa Pyne <vipyne@gmail.com>
2025-12-03 17:28:37 +00:00
Rpcd
0bbb14eb9b Update src/pipecat/services/mcp_service.py
Co-authored-by: Vanessa Pyne <vipyne@gmail.com>
2025-12-03 17:28:29 +00:00
Mark Backman
050f287ec4 Merge pull request #3072 from jjmaldonis/deepgram/add-deepgram-request-ids-to-debug-logs
deepgram: added request IDs to debug logs
2025-12-03 09:37:25 -05:00
Jason Maldonis
e6f5561785 updated changelog 2025-12-03 08:18:09 -06:00
Jason Maldonis
2df91f4b37 fixed linting 2025-12-03 08:09:16 -06:00
Jason Maldonis
7db49b9067 deepgram: added request IDs to debug logs
Deepgram request IDs are necessary for investigating behavior at the
request level. This commit adds DEBUG logs that print Deepgram request
IDs when using Deepgram's STT or TTS.
2025-12-03 08:09:13 -06:00
Vanessa Pyne
7c497bdc89 Merge pull request #3130 from pipecat-ai/vp-nvidia-docs
update nvidia services naming
2025-12-02 13:04:16 -06:00
vipyne
1aa4247d2b remove nim from pyproject.toml 2025-12-02 12:55:13 -06:00
laurent
1ffa9ff51f Gradium integration. 2025-12-02 13:34:51 +01:00
Rpcd
435b53f1a0 Update src/pipecat/services/mcp_service.py
Co-authored-by: Vanessa Pyne <vipyne@gmail.com>
2025-12-02 09:22:08 +00:00
Rpcd
406bdfad0d Update src/pipecat/services/mcp_service.py
Co-authored-by: Vanessa Pyne <vipyne@gmail.com>
2025-12-02 09:21:59 +00:00
vipyne
acba544e6f pr notes for nvidia service name change 2025-12-01 22:41:17 -06:00
vipyne
5d93c64ee5 typo fixes and uv.lock update 2025-12-01 22:41:17 -06:00
vipyne
de10bc8803 changelog for riva,nim -> nvidia name change 2025-12-01 22:41:17 -06:00
vipyne
36f5c1722d deprecate riva and nim service paths in favor of nvidia 2025-12-01 22:41:17 -06:00
vipyne
a8280522e5 examples: rename nvidia foundational examples 2025-12-01 22:41:17 -06:00
vipyne
05d65dfdd3 Update NVIDIA NIM and Riva services to Nvidia
- pip install pipecat-ai[nim]
- pip install pipecat-ai[riva]

+ pip install pipecat-ai[nvidia]

and

- from pipecat.services.nim.llm import NimLLMService
+ from pipecat.services.nvidia.llm import NvidiaLLMService

- from pipecat.services.riva.stt import RivaSTTService
+ from pipecat.services.nvidia.stt import NvidiaSTTService

- from pipecat.services.riva.tts import RivaTTSService
+ from pipecat.services.nvidia.tts import NvidiaTTSService
2025-12-01 22:41:17 -06:00
Aleix Conchillo Flaqué
a3962e3b47 LLMTextFrame: don't override skip_tts 2025-12-01 18:37:07 -08:00
Aleix Conchillo Flaqué
cd231cf829 Merge pull request #3120 from pipecat-ai/aleix/function-calls-wait-for-all
allow waiting for all function calls to complete
2025-12-01 18:35:53 -08:00
Aleix Conchillo Flaqué
9fafc1692d update uv.lock 2025-12-01 18:32:00 -08:00
Aleix Conchillo Flaqué
7648d0436c examples(19): linting 2025-12-01 18:30:34 -08:00
Aleix Conchillo Flaqué
bff8747e38 LLMService: allow waiting for all function calls to complete 2025-12-01 18:30:25 -08:00
Mark Backman
d227c0c097 Merge pull request #3155 from pipecat-ai/mb/fix-sarvam-tts-not-flushing
fix: flush audio in SarvamTTSService
2025-12-01 17:22:33 -05:00
Mark Backman
9ccde60521 fix: flush audio in SarvamTTSService 2025-12-01 17:18:34 -05:00
Mark Backman
b84a40666c Merge pull request #3156 from pipecat-ai/mb/deepgram-stt-stopped-frame
fix: DeepgramTTSService, let the base class push TTSStoppedFrame
2025-12-01 17:18:19 -05:00
Mark Backman
e72b135a4c fix: DeepgramTTSService, let the base class push TTSStoppedFrame 2025-12-01 17:15:51 -05:00
RuiDaniel
7961f8a664 same behaviour on error 2025-11-25 18:35:59 +00:00
RuiDaniel
4ca143e8af add mcp filters to client 2025-11-25 18:27:22 +00:00
75 changed files with 3009 additions and 1485 deletions

174
.github/workflows/generate-changelog.yml vendored Normal file
View File

@@ -0,0 +1,174 @@
name: Generate Changelog for Release
on:
workflow_dispatch:
inputs:
version:
description: "Release version (e.g., 0.0.97)"
required: true
type: string
date:
description: "Release date (YYYY-MM-DD format, defaults to today)"
required: false
type: string
default: ""
permissions:
contents: write
pull-requests: write
jobs:
generate-changelog:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
enable-cache: true
- name: Install dependencies
run: |
uv sync --group dev
- name: Set release date
id: set_date
run: |
if [ -z "${{ inputs.date }}" ]; then
RELEASE_DATE=$(date +%Y-%m-%d)
echo "Using today's date: $RELEASE_DATE"
else
RELEASE_DATE="${{ inputs.date }}"
echo "Using provided date: $RELEASE_DATE"
fi
echo "release_date=$RELEASE_DATE" >> $GITHUB_OUTPUT
- name: Validate inputs
run: |
# Validate version format (basic check)
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+.*$ ]]; then
echo "Error: Version must be in format X.Y.Z (e.g., 0.0.97)"
exit 1
fi
# Validate date format if provided
if [ -n "${{ inputs.date }}" ]; then
if ! date -d "${{ inputs.date }}" >/dev/null 2>&1; then
# Try macOS date format
if ! date -j -f "%Y-%m-%d" "${{ inputs.date }}" >/dev/null 2>&1; then
echo "Error: Date must be in YYYY-MM-DD format (e.g., 2025-12-04)"
exit 1
fi
fi
fi
- name: Check for changelog fragments
id: check_fragments
run: |
FRAGMENT_COUNT=$(find changelog -name "*.md" ! -name "_template.md.j2" | wc -l | tr -d ' ')
echo "fragment_count=$FRAGMENT_COUNT" >> $GITHUB_OUTPUT
if [ "$FRAGMENT_COUNT" -eq "0" ]; then
echo "❌ Error: No changelog fragments found in changelog/"
echo ""
echo "Cannot create a release without changelog entries."
echo "Add changelog fragments to the changelog/ directory (e.g., 1234.added.md) and try again."
exit 1
fi
# Validate fragment types
VALID_TYPES="added changed deprecated removed fixed security"
INVALID_FRAGMENTS=""
for file in changelog/*.md; do
# Skip template
if [[ "$file" == "changelog/_template.md.j2" ]]; then
continue
fi
# Extract type from filename (e.g., 1234.added.md -> added)
filename=$(basename "$file")
# Handle both 1234.added.md and 1234.added.2.md patterns
type=$(echo "$filename" | sed -E 's/^[0-9]+\.([a-z]+)(\.[0-9]+)?\.md$/\1/')
# Check if type is valid
if ! echo "$VALID_TYPES" | grep -wq "$type"; then
INVALID_FRAGMENTS="$INVALID_FRAGMENTS\n - $filename (type: '$type')"
fi
done
if [ -n "$INVALID_FRAGMENTS" ]; then
echo "❌ Error: Invalid changelog fragment types found:"
echo -e "$INVALID_FRAGMENTS"
echo ""
echo "Valid types are: $VALID_TYPES"
echo "Example: 1234.added.md, 5678.fixed.md"
exit 1
fi
echo "✓ Found $FRAGMENT_COUNT changelog fragment(s)"
echo "has_fragments=true" >> $GITHUB_OUTPUT
- name: Preview changelog
run: |
echo "## Preview of changelog for version ${{ inputs.version }}"
echo ""
uv run towncrier build --draft --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}"
- name: Build changelog
run: |
uv run towncrier build --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}" --yes
- name: Create Pull Request
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: "Update changelog for version ${{ inputs.version }}"
title: "Release ${{ inputs.version }} - Changelog Update"
body: |
## Changelog Update for Release ${{ inputs.version }}
This PR updates the CHANGELOG.md with all changes for version **${{ inputs.version }}**.
### Summary
- **Version:** ${{ inputs.version }}
- **Date:** ${{ steps.set_date.outputs.release_date }}
- **Fragments processed:** ${{ steps.check_fragments.outputs.fragment_count }}
### What this PR does
- ✅ Adds new release section to CHANGELOG.md
- ✅ Removes processed changelog fragments
- ✅ Ready to merge for release
### Next Steps
1. Review the changelog entries below
2. Make any necessary edits to CHANGELOG.md if needed
3. Merge this PR
4. Continue with your release process
---
<details>
<summary>📋 Preview of changes</summary>
The changelog has been updated with entries from the following fragments:
```bash
${{ steps.check_fragments.outputs.fragment_count }} fragments processed
```
</details>
branch: changelog-${{ inputs.version }}
delete-branch: true
labels: |
changelog
release

View File

@@ -5,25 +5,115 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
<!-- towncrier release notes start -->
## [0.0.97] - 2025-12-05
### Added
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
speech-to-text and text-to-speech functionality using Gradium's API.
- Additions for `AsyncAITTSService` and `AsyncAIHttpTTSService`:
- Added new `languages`: `pt`, `nl`, `ar`, `ru`, `ro`, `ja`, `he`, `hy`,
`tr`, `hi`, `zh`.
- Updated the default model to `asyncflow_multilingual_v1.0` for improved
accuracy and broader language coverage.
- Added optional tool and tool output filters for MCP services.
### Changed
- Updated Deepgram logging to include Deepgram request IDs for improved
debugging.
- Text Aggregation Improvements:
- **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
`AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
enables the aggregator to return multiple results based on the provided
text.
- Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
the base class's sentence detection logic.
- Improved interruption handling to prevent bots from repeating themselves. LLM
services that return multiple sentences in a single response (e.g.,
`GoogleLLMService`) are now split into individual sentences before being sent
to TTS. This ensures interruptions occur at sentence boundaries, preventing
the bot from repeating content after being interrupted during long responses.
- Updated `AICFilter` to use Quail STT as the default model
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
interaction (e.g., voice agents, speech-to-text) and operates at a native
sample rate of 16 kHz with fixed enhancement parameters.
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is
called with an exception, the file name and line number where the exception
occured are now logged.
- Updated Smart Turn model weights to v3.1.
- Smart Turn analyzer now uses the full context of the turn rather than just
the audio since VAD last triggered.
- Updated `CartesiaSTTService` to return the full transcription `result` in the
`TranscriptionFrame` and `InterimTranscriptionFrame`. This provides access to
word timestamp data.
- `HumeTTSService` changes:
- Added tracking headers (`X-Hume-Client-Name` and `X-Hume-Client-Version`)
to all requests made by `HumeTTSService` to the Hume API for better usage
tracking and analytics.
- Added `stop()` and `cancel()` cleanup methods to `HumeTTSService` to
properly close the HTTP client and prevent resource leaks.
### Deprecated
- NVIDIA Services name changes (all functionality is unchanged):
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
- `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
- `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
- Use `uv pip install pipecat-ai[nvidia]` instead of
`uv pip install pipecat-ai[riva]`
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
has any effect. Noise gating is now handled automatically by the AIC VAD
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
### Fixed
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
always set to `us-east-1` when providing an AWS_REGION env var.
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
multiple times for `KEEP` or `AGGREGATE` patterns.
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was always
set to `us-east-1` when providing an AWS_REGION env var.
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
spoken. Now, audio is flushed when the TTS services receives the
`LLMFullResponseEndFrame` or `EndFrame`.
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
incorrectly pushed after a functional call. This caused an issue with the
voice-ui-kit's conversational panel rending of the LLM output after a
function call.
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
services.
- Fixed an issue that caused `WebsocketService` instances to attempt
reconnection during shutdown.
- Fixed an issue in `ElevenLabsTTSService` where character usage metrics were
only reported on the first TTS generation per turn.
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃

View File

@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
**Examples:**
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
#### Key requirements:

View File

@@ -17,24 +17,121 @@ We welcome contributions of all kinds! Your help is appreciated. Follow these st
git checkout -b your-branch-name
```
4. **Make your changes**: Edit or add files as necessary.
5. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
6. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
5. **Add a changelog entry**: Create a changelog fragment file (see [Changelog Entries](#changelog-entries) below).
6. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
7. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
```bash
git commit -m "Description of your changes"
```
7. **Push your changes**: Push your branch to your forked repository.
8. **Push your changes**: Push your branch to your forked repository.
```bash
git push origin your-branch-name
```
8. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
9. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
> Important: Describe the changes you've made clearly!
Our maintainers will review your PR, and once everything is good, your contributions will be merged!
## Changelog Entries
Every pull request that makes a user-facing change should include a changelog entry. We use a changelog fragment system to avoid merge conflicts.
### Creating a Changelog Fragment
1. Create a new file in the `changelog/` directory with this naming pattern:
```
<PR_number>.<type>.md
```
2. Choose the appropriate type:
- `added.md` - New features
- `changed.md` - Changes in existing functionality
- `deprecated.md` - Soon-to-be removed features
- `removed.md` - Removed features
- `fixed.md` - Bug fixes
- `security.md` - Security fixes
3. Write your changelog entry as a Markdown bullet point. Include the `-` at the start:
**Example files:**
`changelog/1234.added.md`:
```markdown
- Added support for Anthropic Claude 3.5 Sonnet with improved streaming performance.
```
`changelog/5678.fixed.md`:
```markdown
- Fixed an issue where audio frames were dropped during high-load scenarios.
```
**For entries with nested bullets:**
`changelog/1234.changed.md`:
```markdown
- Updated service configuration:
- Changed default timeout to 30 seconds
- Added retry logic for failed connections
```
### Multiple Changes in One PR
**Different types of changes:** Create separate fragment files for each type:
```
changelog/1234.added.md
changelog/1234.fixed.md
```
**Multiple changes of the same type:** Create numbered fragment files:
```
changelog/1234.changed.md
changelog/1234.changed.2.md
```
**Related changes:** Use nested bullets in a single fragment:
```markdown
- Updated service configuration:
- Changed default timeout to 30 seconds
- Added retry logic for failed connections
```
**Rule of thumb:** One logical change per fragment file. If changes are unrelated, use separate files.
### Preview Your Changes
To see what your changelog entry will look like:
```bash
towncrier build --draft --version Unreleased
```
This won't modify any files, just show you a preview.
### When to Skip Changelog Entries
You can skip adding a changelog entry for:
- Documentation-only changes
- Internal refactoring with no user-facing impact
- Test-only changes
- CI/build configuration changes
If you're unsure whether your change needs a changelog entry, ask in your PR!
## Dependency Management
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. The `uv.lock` file is committed to ensure reproducible builds.

View File

@@ -74,9 +74,9 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
| Category | Services |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |

16
changelog/_template.md.j2 Normal file
View File

@@ -0,0 +1,16 @@
{% for section, _ in sections.items() %}
{% if sections[section] %}
{% for category, val in definitions.items() if category in sections[section]%}
### {{ definitions[category]['name'] }}
{% for text, values in sections[section][category].items() %}
{{ text }}
{% endfor %}
{% endfor %}
{% else %}
No significant changes.
{% endif %}
{% endfor %}

View File

@@ -119,7 +119,6 @@ def import_core_modules():
"pipecat.observers",
"pipecat.runner",
"pipecat.serializers",
"pipecat.sync",
"pipecat.transcriptions",
"pipecat.utils",
]

View File

@@ -30,7 +30,6 @@ Quick Links
Runner <api/pipecat.runner>
Serializers <api/pipecat.serializers>
Services <api/pipecat.services>
Sync <api/pipecat.sync>
Transcriptions <api/pipecat.transcriptions>
Transports <api/pipecat.transports>
Utils <api/pipecat.utils>
Utils <api/pipecat.utils>

View File

@@ -73,6 +73,9 @@ GOOGLE_CLOUD_PROJECT_ID=...
GOOGLE_CLOUD_LOCATION=...
GOOGLE_TEST_CREDENTIALS=...
# Gradium
GRAPDIUM_API_KEY=...
# Grok
GROK_API_KEY=...
@@ -191,4 +194,4 @@ TWILIO_AUTH_TOKEN=...
WHATSAPP_TOKEN=...
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
WHATSAPP_PHONE_NUMBER_ID=...
WHATSAPP_APP_SECRET=...
WHATSAPP_APP_SECRET=...

View File

@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.riva.tts import FastPitchTTSService
from pipecat.services.nvidia.tts import NvidiaTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -36,7 +36,7 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
task = PipelineTask(
Pipeline([tts, transport.output()]),

View File

@@ -0,0 +1,127 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.gradium.stt import GradiumSTTService
from pipecat.services.gradium.tts import GradiumTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
tts = GradiumTTSService(
api_key=os.getenv("GRADIUM_API_KEY"),
voice_id="YTpq7expH9539ERJ",
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
},
]
context = LLMContext(messages)
context_aggregator = LLMContextAggregatorPair(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.nim.llm import NimLLMService
from pipecat.services.riva.stt import RivaSTTService
from pipecat.services.riva.tts import RivaTTSService
from pipecat.services.nvidia.llm import NvidiaLLMService
from pipecat.services.nvidia.stt import NvidiaSTTService
from pipecat.services.nvidia.tts import NvidiaTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -59,11 +59,13 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
llm = NvidiaLLMService(
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
)
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
messages = [
{

View File

@@ -76,7 +76,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
llm = FireworksLLMService(
api_key=os.getenv("FIREWORKS_API_KEY"),
model="accounts/fireworks/models/llama-v3p1-405b-instruct",
model="accounts/fireworks/models/gpt-oss-20b",
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.

View File

@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.nim.llm import NimLLMService
from pipecat.services.nvidia.llm import NvidiaLLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# text_filters=[MarkdownTextFilter()],
)
llm = NimLLMService(
llm = NvidiaLLMService(
api_key=os.getenv("NVIDIA_API_KEY"),
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
# Recommended when turning thinking off
params=NimLLMService.InputParams(temperature=0.0),
params=NvidiaLLMService.InputParams(temperature=0.0),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.

View File

@@ -14,20 +14,13 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
LLMRunFrame,
LLMSetToolsFrame,
LLMUpdateSettingsFrame,
TranscriptionMessage,
)
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.runner.types import RunnerArguments

View File

@@ -19,7 +19,6 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport

View File

@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.event_notifier import EventNotifier
load_dotenv(override=True)

View File

@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams, LLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams, LLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.llm_service import LLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -64,11 +64,14 @@ class UrlToImageProcessor(FrameProcessor):
await self.push_frame(frame, direction)
def extract_url(self, text: str):
data = json.loads(text)
if "artObject" in data:
return data["artObject"]["webImage"]["url"]
if "artworks" in data and len(data["artworks"]):
return data["artworks"][0]["webImage"]["url"]
try:
data = json.loads(text)
if "artObject" in data:
return data["artObject"]["webImage"]["url"]
if "artworks" in data and len(data["artworks"]):
return data["artworks"][0]["webImage"]["url"]
except:
pass
return None
@@ -88,6 +91,23 @@ class UrlToImageProcessor(FrameProcessor):
logger.error(error_msg)
# full list of tools available from rijksmuseum MCP:
# - get_artwork_details
# - get_artwork_image
# - get_user_sets
# - get_user_set_details
# - open_image_in_browser
# - get_artist_timeline
mcp_tools_filter = ["get_artwork_details", "get_artwork_image", "open_image_in_browser"]
def open_image_output_filter(output: str):
pattern = r"Successfully opened image in browser: "
text_to_print = re.sub(pattern, "", output)
print(f"🖼️ link to high resolution artwork: {text_to_print}")
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -136,7 +156,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# https://github.com/r-huijts/rijksmuseum-mcp
args=["-y", "mcp-server-rijksmuseum"],
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
)
),
# Optional
tools_filter=mcp_tools_filter, # Optional
tools_output_filters={"open_image_in_browser": open_image_output_filter},
)
except Exception as e:
logger.error(f"error setting up mcp")

View File

@@ -67,13 +67,14 @@ class UrlToImageProcessor(FrameProcessor):
await self.push_frame(frame, direction)
def extract_url(self, text: str):
data = json.loads(text)
if "artObject" in data:
return data["artObject"]["webImage"]["url"]
if "artworks" in data and len(data["artworks"]):
return data["artworks"][0]["webImage"]["url"]
return None
try:
data = json.loads(text)
if "artObject" in data:
return data["artObject"]["webImage"]["url"]
if "artworks" in data and len(data["artworks"]):
return data["artworks"][0]["webImage"]["url"]
except:
pass
async def run_image_process(self, image_url: str):
try:

View File

@@ -63,6 +63,7 @@ fireworks = []
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
gladia = [ "pipecat-ai[websockets-base]" ]
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
gradium = [ "pipecat-ai[websockets-base]" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]
@@ -83,8 +84,8 @@ mistral = []
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
neuphonic = [ "pipecat-ai[websockets-base]" ]
nim = []
noisereduce = [ "noisereduce~=3.0.3" ]
nvidia = [ "nvidia-riva-client~=2.21.1" ]
openai = [ "pipecat-ai[websockets-base]" ]
openpipe = [ "openpipe>=4.50.0,<6" ]
openrouter = []
@@ -93,7 +94,7 @@ playht = [ "pipecat-ai[websockets-base]" ]
qwen = []
remote-smart-turn = []
rime = [ "pipecat-ai[websockets-base]" ]
riva = [ "nvidia-riva-client~=2.21.1" ]
riva = [ "pipecat-ai[nvidia]" ]
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
sambanova = []
@@ -129,6 +130,7 @@ dev = [
"setuptools~=78.1.1",
"setuptools_scm~=8.3.1",
"python-dotenv>=1.0.1,<2.0.0",
"towncrier~=25.8.0",
]
docs = [
@@ -159,7 +161,7 @@ where = ["src"]
"src/pipecat/audio/dtmf/dtmf-star.wav",
]
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.1-cpu.onnx"]
[tool.pytest.ini_options]
addopts = "--verbose"
@@ -206,3 +208,44 @@ convention = "google"
command_line = "--module pytest"
source = ["src"]
omit = ["*/tests/*"]
[tool.towncrier]
package = "pipecat"
package_dir = "src"
filename = "CHANGELOG.md"
directory = "changelog"
start_string = "<!-- towncrier release notes start -->\n"
template = "changelog/_template.md.j2"
title_format = "## [{version}] - {project_date}"
underlines = ["", "", ""]
wrap = true
[[tool.towncrier.type]]
directory = "added"
name = "Added"
showcontent = true
[[tool.towncrier.type]]
directory = "changed"
name = "Changed"
showcontent = true
[[tool.towncrier.type]]
directory = "deprecated"
name = "Deprecated"
showcontent = true
[[tool.towncrier.type]]
directory = "removed"
name = "Removed"
showcontent = true
[[tool.towncrier.type]]
directory = "fixed"
name = "Fixed"
showcontent = true
[[tool.towncrier.type]]
directory = "security"
name = "Security"
showcontent = true

View File

@@ -103,7 +103,7 @@ TESTS_07 = [
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
@@ -136,7 +136,7 @@ TESTS_14 = [
("14g-function-calling-grok.py", EVAL_WEATHER),
("14h-function-calling-azure.py", EVAL_WEATHER),
("14i-function-calling-fireworks.py", EVAL_WEATHER),
("14j-function-calling-nim.py", EVAL_WEATHER),
("14j-function-calling-nvidia.py", EVAL_WEATHER),
("14k-function-calling-cerebras.py", EVAL_WEATHER),
("14m-function-calling-openrouter.py", EVAL_WEATHER),
("14n-function-calling-perplexity.py", EVAL_WEATHER),

View File

@@ -28,7 +28,6 @@ from pipecat.metrics.metrics import MetricsData, SmartTurnMetricsData
STOP_SECS = 3
PRE_SPEECH_MS = 0
MAX_DURATION_SECONDS = 8 # Max allowed segment duration
USE_ONLY_LAST_VAD_SEGMENT = True
class SmartTurnParams(BaseTurnParams):
@@ -43,8 +42,6 @@ class SmartTurnParams(BaseTurnParams):
stop_secs: float = STOP_SECS
pre_speech_ms: float = PRE_SPEECH_MS
max_duration_secs: float = MAX_DURATION_SECONDS
# not exposing this for now yet until the model can handle it.
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
class SmartTurnTimeoutException(Exception):
@@ -160,7 +157,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
state, result = await loop.run_in_executor(
self._executor, self._process_speech_segment, self._audio_buffer
)
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
if state == EndOfTurnState.COMPLETE:
self._clear(state)
logger.debug(f"End of Turn result: {state}")
return state, result

View File

@@ -42,17 +42,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
Args:
smart_turn_model_path: Path to the ONNX model file. If this is not
set, the bundled smart-turn-v3.0 model will be used.
set, the bundled smart-turn-v3.1-cpu model will be used.
cpu_count: The number of CPUs to use for inference. Defaults to 1.
**kwargs: Additional arguments passed to BaseSmartTurn.
"""
super().__init__(**kwargs)
logger.debug("Loading Local Smart Turn v3 model...")
if not smart_turn_model_path:
# Load bundled model
model_name = "smart-turn-v3.0.onnx"
model_name = "smart-turn-v3.1-cpu.onnx"
package_path = "pipecat.audio.turn.smart_turn.data"
try:
@@ -70,6 +68,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
impresources.files(package_path).joinpath(model_name)
)
logger.debug(f"Loading Local Smart Turn v3.x model from {smart_turn_model_path}...")
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
so.inter_op_num_threads = 1
@@ -79,7 +79,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
logger.debug("Loaded Local Smart Turn v3")
logger.debug("Loaded Local Smart Turn v3.x")
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
"""Predict end-of-turn using local ONNX model."""

View File

@@ -18,8 +18,10 @@ from loguru import logger
from pipecat.audio.dtmf.types import KeypadEntry
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
EndFrame,
Frame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMMessagesUpdateFrame,
LLMTextFrame,
OutputDTMFUrgentFrame,
@@ -149,11 +151,18 @@ class IVRProcessor(FrameProcessor):
elif isinstance(frame, LLMTextFrame):
# Process text through the pattern aggregator
result = await self._aggregator.aggregate(frame.text)
if result:
async for result in self._aggregator.aggregate(frame.text):
# Push aggregated text that doesn't contain XML patterns
await self.push_frame(LLMTextFrame(result.text), direction)
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
# Flush any remaining text from the aggregator
remaining = await self._aggregator.flush()
if remaining:
await self.push_frame(LLMTextFrame(remaining.text), direction)
# Push the end frame
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.services.llm_service import LLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
class NotifierGate(FrameProcessor):

View File

@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
"""
text: str
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
# Whether any necessary inter-frame (leading/trailing) spaces are already
# included in the text.
# NOTE: Ideally this would be available at init time with a default value,
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
self.includes_inter_frame_spaces = False
self.append_to_context = True
@@ -1632,22 +1632,22 @@ class LLMFullResponseStartFrame(ControlFrame):
more TextFrames and a final LLMFullResponseEndFrame.
"""
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
@dataclass
class LLMFullResponseEndFrame(ControlFrame):
"""Frame indicating the end of an LLM response."""
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
@dataclass

View File

@@ -9,7 +9,7 @@
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
class GatedLLMContextAggregator(FrameProcessor):

View File

@@ -83,8 +83,7 @@ class LLMTextProcessor(FrameProcessor):
await self._text_aggregator.reset()
async def _handle_llm_text(self, in_frame: LLMTextFrame):
aggregation = await self._text_aggregator.aggregate(in_frame.text)
if aggregation:
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
out_frame = AggregatedTextFrame(
text=aggregation.text,
aggregated_by=aggregation.type,
@@ -92,15 +91,13 @@ class LLMTextProcessor(FrameProcessor):
out_frame.skip_tts = in_frame.skip_tts
await self.push_frame(out_frame)
async def _handle_llm_end(self, skip_tts: bool = False):
# Flush any remaining aggregated text at the end of the LLM response
aggregation = self._text_aggregator.text
await self._text_aggregator.reset()
text = aggregation.text.strip()
if text:
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
# Flush any remaining text
remaining = await self._text_aggregator.flush()
if remaining:
out_frame = AggregatedTextFrame(
text=text,
aggregated_by=aggregation.type,
text=remaining.text,
aggregated_by=remaining.type,
)
out_frame.skip_tts = skip_tts
await self.push_frame(out_frame)

View File

@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
from pipecat.frames.frames import Frame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
class WakeNotifierFilter(FrameProcessor):

View File

@@ -12,6 +12,7 @@ management, and frame flow control mechanisms.
"""
import asyncio
import traceback
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence, Tuple, Type
@@ -677,7 +678,17 @@ class FrameProcessor(BaseObject):
if not error.processor:
error.processor = self
await self._call_event_handler("on_error", error)
logger.error(f"{error.processor} error: {error.error}")
if error.exception:
tb = traceback.extract_tb(error.exception.__traceback__)
last = tb[-1]
error_message = (
f"{error.processor} exception ({last.filename}:{last.lineno}): {error.error}"
)
else:
error_message = f"{error.processor} error: {error.error}"
logger.error(error_message)
await self.push_frame(error, FrameDirection.UPSTREAM)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):

View File

@@ -935,8 +935,8 @@ class RTVIObserverParams:
system_logs_enabled: Indicates if system logs should be sent.
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
Note: if using this to avoid sending secure information, be sure to also disable
bot_llm_enabled to avoid leaking through LLM messages.
Note: if using this to avoid sending secure information, be sure to also disable
bot_llm_enabled to avoid leaking through LLM messages.
bot_output_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of

View File

@@ -56,6 +56,17 @@ def language_to_async_language(language: Language) -> Optional[str]:
Language.ES: "es",
Language.DE: "de",
Language.IT: "it",
Language.PT: "pt",
Language.NL: "nl",
Language.AR: "ar",
Language.RU: "ru",
Language.RO: "ro",
Language.JA: "ja",
Language.HE: "he",
Language.HY: "hy",
Language.TR: "tr",
Language.HI: "hi",
Language.ZH: "zh",
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
@@ -74,7 +85,7 @@ class AsyncAITTSService(InterruptibleTTSService):
language: Language to use for synthesis.
"""
language: Optional[Language] = Language.EN
language: Optional[Language] = None
def __init__(
self,
@@ -83,7 +94,7 @@ class AsyncAITTSService(InterruptibleTTSService):
voice_id: str,
version: str = "v1",
url: str = "wss://api.async.ai/text_to_speech/websocket/ws",
model: str = "asyncflow_v2.0",
model: str = "asyncflow_multilingual_v1.0",
sample_rate: Optional[int] = None,
encoding: str = "pcm_s16le",
container: str = "raw",
@@ -99,7 +110,7 @@ class AsyncAITTSService(InterruptibleTTSService):
https://docs.async.ai/list-voices-16699698e0
version: Async API version.
url: WebSocket URL for Async TTS API.
model: TTS model to use (e.g., "asyncflow_v2.0").
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
sample_rate: Audio sample rate.
encoding: Audio encoding format.
container: Audio container format.
@@ -128,7 +139,7 @@ class AsyncAITTSService(InterruptibleTTSService):
},
"language": self.language_to_service_language(params.language)
if params.language
else "en",
else None,
}
self.set_model_name(model)
@@ -357,7 +368,7 @@ class AsyncAIHttpTTSService(TTSService):
language: Language to use for synthesis.
"""
language: Optional[Language] = Language.EN
language: Optional[Language] = None
def __init__(
self,
@@ -365,7 +376,7 @@ class AsyncAIHttpTTSService(TTSService):
api_key: str,
voice_id: str,
aiohttp_session: aiohttp.ClientSession,
model: str = "asyncflow_v2.0",
model: str = "asyncflow_multilingual_v1.0",
url: str = "https://api.async.ai",
version: str = "v1",
sample_rate: Optional[int] = None,
@@ -380,7 +391,7 @@ class AsyncAIHttpTTSService(TTSService):
api_key: Async API key.
voice_id: ID of the voice to use for synthesis.
aiohttp_session: An aiohttp session for making HTTP requests.
model: TTS model to use (e.g., "asyncflow_v2.0").
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
url: Base URL for Async API.
version: API version string for Async API.
sample_rate: Audio sample rate.
@@ -404,7 +415,7 @@ class AsyncAIHttpTTSService(TTSService):
},
"language": self.language_to_service_language(params.language)
if params.language
else "en",
else None,
}
self.set_voice(voice_id)
self.set_model_name(model)

View File

@@ -20,7 +20,6 @@ from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
@@ -349,6 +348,7 @@ class CartesiaSTTService(WebsocketSTTService):
self._user_id,
time_now_iso8601(),
language,
result=data,
)
)
await self._handle_transcription(transcript, is_final, language)
@@ -361,5 +361,6 @@ class CartesiaSTTService(WebsocketSTTService):
self._user_id,
time_now_iso8601(),
language,
result=data,
)
)

View File

@@ -244,6 +244,11 @@ class DeepgramFluxSTTService(WebsocketSTTService):
additional_headers={"Authorization": f"Token {self._api_key}"},
)
headers = {
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
# Creating the receiver task
if not self._receive_task:
self._receive_task = self.create_task(

View File

@@ -234,6 +234,13 @@ class DeepgramSTTService(STTService):
if not await self._connection.start(options=self._settings, addons=self._addons):
await self.push_error(error_msg=f"Unable to connect to Deepgram")
else:
headers = {
k: v
for k, v in self._connection._socket.response.headers.items()
if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
async def _disconnect(self):
if await self._connection.is_connected():

View File

@@ -71,7 +71,12 @@ class DeepgramTTSService(WebsocketTTSService):
encoding: Audio encoding format. Defaults to "linear16".
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
super().__init__(
sample_rate=sample_rate,
pause_frame_processing=True,
push_stop_frames=True,
**kwargs,
)
self._api_key = api_key
self._base_url = base_url
@@ -165,6 +170,11 @@ class DeepgramTTSService(WebsocketTTSService):
self._websocket = await websocket_connect(url, additional_headers=headers)
headers = {
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
@@ -231,7 +241,6 @@ class DeepgramTTSService(WebsocketTTSService):
logger.trace(f"Received Flushed: {msg}")
# Flushed indicates the end of audio generation for the current buffer
# This happens after flush_audio() is called
await self.push_frame(TTSStoppedFrame())
elif msg_type == "Cleared":
logger.trace(f"Received Cleared: {msg}")
# Buffer has been cleared after interruption
@@ -286,7 +295,7 @@ class DeepgramTTSService(WebsocketTTSService):
speak_msg = {"type": "Speak", "text": text}
await self._get_websocket().send(json.dumps(speak_msg))
# The actual audio frames will be handled in _receive_messages
# The audio frames will be handled in _receive_messages
yield None
except Exception as e:

View File

@@ -160,7 +160,7 @@ def build_elevenlabs_voice_settings(
class PronunciationDictionaryLocator(BaseModel):
"""Locator for a pronunciation dictionary.
Attributes:
Parameters:
pronunciation_dictionary_id: The ID of the pronunciation dictionary.
version_id: The version ID of the pronunciation dictionary.
"""
@@ -731,10 +731,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._websocket.send(json.dumps(msg))
logger.trace(f"Created new context {self._context_id}")
await self._send_text(text)
await self.start_tts_usage_metrics(text)
else:
await self._send_text(text)
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:
yield TTSStoppedFrame()
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -0,0 +1,5 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

View File

@@ -0,0 +1,239 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Gradium's speech-to-text service implementation.
This module provides integration with Gradium's real-time speech-to-text
WebSocket API for streaming audio transcription.
"""
import base64
import json
from typing import AsyncGenerator
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
raise Exception(f"Missing module: {e}")
SAMPLE_RATE = 24000
class GradiumSTTService(WebsocketSTTService):
"""Gradium real-time speech-to-text service.
Provides real-time speech transcription using Gradium's WebSocket API.
Supports both interim and final transcriptions with configurable parameters
for audio processing and connection management.
"""
def __init__(
self,
*,
api_key: str,
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
json_config: str | None = None,
**kwargs,
):
"""Initialize the Gradium STT service.
Args:
api_key: Gradium API key for authentication.
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
json_config: Optional JSON configuration string for additional model settings.
**kwargs: Additional arguments passed to parent STTService class.
"""
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
self._api_key = api_key
self._api_endpoint_base_url = api_endpoint_base_url
self._websocket = None
self._json_config = json_config
self._receive_task = None
self._audio_buffer = bytearray()
self._chunk_size_ms = 80
self._chunk_size_bytes = 0
def can_generate_metrics(self) -> bool:
"""Check if the service can generate metrics.
Returns:
True if metrics generation is supported.
"""
return True
async def start(self, frame: StartFrame):
"""Start the speech-to-text service.
Args:
frame: Start frame to begin processing.
"""
await super().start(frame)
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the speech-to-text service.
Args:
frame: End frame to stop processing.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the speech-to-text service.
Args:
frame: Cancel frame to abort processing.
"""
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text conversion.
Args:
audio: Raw audio bytes to process.
Yields:
None (processing handled via WebSocket messages).
"""
self._audio_buffer.extend(audio)
await self.start_ttfb_metrics()
await self.start_processing_metrics()
while len(self._audio_buffer) >= self._chunk_size_bytes:
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
chunk = base64.b64encode(chunk).decode("utf-8")
msg = {"type": "audio", "audio": chunk}
if self._websocket and self._websocket.state is State.OPEN:
await self._websocket.send(json.dumps(msg))
yield None
@traced_stt
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
"""Record transcription event for tracing."""
pass
async def _connect(self):
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _connect_websocket(self):
try:
if self._websocket and self._websocket.state is State.OPEN:
return
ws_url = self._api_endpoint_base_url
headers = {
"x-api-key": self._api_key,
"x-api-source": "pipecat",
}
self._websocket = await websocket_connect(
ws_url,
additional_headers=headers,
)
await self._call_event_handler("on_connected")
setup_msg = {
"type": "setup",
"input_format": "pcm",
}
if self._json_config is not None:
setup_msg["json_config"] = self._json_config
await self._websocket.send(json.dumps(setup_msg))
ready_msg = await self._websocket.recv()
ready_msg = json.loads(ready_msg)
if ready_msg["type"] == "error":
raise Exception(f"received error {ready_msg['message']}")
if ready_msg["type"] != "ready":
raise Exception(f"unexpected first message type {ready_msg['type']}")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise
async def _disconnect(self):
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._disconnect_websocket()
async def _disconnect_websocket(self):
try:
if self._websocket and self._websocket.state is State.OPEN:
logger.debug("Disconnecting from Gradium STT")
await self._websocket.close()
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
def _get_websocket(self):
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def _process_messages(self):
async for message in self._get_websocket():
try:
data = json.loads(message)
await self._process_response(data)
except json.JSONDecodeError:
logger.warning(f"Received non-JSON message: {message}")
async def _receive_messages(self):
while True:
await self._process_messages()
logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
await self._connect_websocket()
async def _process_response(self, msg):
type_ = msg.get("type", "")
if type_ == "text":
await self._handle_text(msg["text"])
elif type_ == "end_of_stream":
await self._handle_end_of_stream()
elif type_ == "error":
await self.push_error(error_msg=f"Error: {msg}")
async def _handle_end_of_stream(self):
"""Handle termination message."""
logger.debug("Received end_of_stream message from server")
async def _handle_text(self, text: str):
"""Handle transcription results."""
await self.push_frame(
TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
)
)

View File

@@ -0,0 +1,315 @@
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
"""Gradium Text-to-Speech service implementation."""
import base64
import json
import uuid
from typing import Any, AsyncGenerator, Mapping, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleWordTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
from websockets import ConnectionClosedOK
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Gradium, you need to `pip install pipecat-ai[gradium]`.")
raise Exception(f"Missing module: {e}")
SAMPLE_RATE = 48000
class GradiumTTSService(InterruptibleWordTTSService):
"""Text-to-Speech service using Gradium's websocket API."""
class InputParams(BaseModel):
"""Configuration parameters for Gradium TTS service.
Parameters:
temp: Temperature to be used for generation, defaults to 0.6.
"""
temp: Optional[float] = 0.6
def __init__(
self,
*,
api_key: str,
voice_id: str = "YTpq7expH9539ERJ",
url: str = "wss://eu.api.gradium.ai/api/speech/tts",
model: str = "default",
json_config: Optional[str] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Gradium TTS service.
Args:
api_key: Gradium API key for authentication.
voice_id: the voice identifier.
url: Gradium websocket API endpoint.
model: Model ID to use for synthesis.
json_config: Optional JSON configuration string for additional model settings.
params: Additional configuration parameters.
**kwargs: Additional arguments passed to parent class.
"""
# Initialize with parent class settings for proper frame handling
super().__init__(
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=SAMPLE_RATE,
**kwargs,
)
params = params or GradiumTTSService.InputParams()
# Store service configuration
self._api_key = api_key
self._url = url
self._voice_id = voice_id
self._json_config = json_config
self._model = model
self._settings = {
"voice_id": voice_id,
"model_name": model,
"output_format": "pcm",
}
# State tracking
self._receive_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Gradium service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Update the TTS model.
Args:
model: The model name to use for synthesis.
"""
self._model = model
await super().set_model(model)
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""
prev_voice = self._voice_id
await super()._update_settings(settings)
if not prev_voice == self._voice_id:
self._settings["voice_id"] = self._voice_id
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
await self._disconnect()
await self._connect()
def _build_msg(self, text: str = "") -> dict:
"""Build JSON message for Gradium API."""
return {"text": text, "type": "text"}
async def start(self, frame: StartFrame):
"""Start the service and establish websocket connection.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close connection.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel current operation and clean up.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
"""Establish websocket connection and start receive task."""
logger.debug(f"{self}: connecting")
# If the server disconnected, cancel the receive-task so that it can be reset below.
if self._websocket is None or self._websocket.state is not State.OPEN:
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._connect_websocket()
if self._websocket and not self._receive_task:
logger.debug(f"{self}: setting receive task")
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Close websocket connection and clean up tasks."""
logger.debug(f"{self}: disconnecting")
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._disconnect_websocket()
async def _connect_websocket(self):
"""Connect to Gradium websocket API with configured settings."""
try:
if self._websocket and self._websocket.state is State.OPEN:
return
headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
self._websocket = await websocket_connect(self._url, additional_headers=headers)
setup_msg = {
"type": "setup",
"output_format": "pcm",
"voice_id": self._voice_id,
}
if self._json_config is not None:
setup_msg["json_config"] = self._json_config
await self._websocket.send(json.dumps(setup_msg))
ready_msg = await self._websocket.recv()
ready_msg = json.loads(ready_msg)
if ready_msg["type"] == "error":
raise Exception(f"received error {ready_msg['message']}")
if ready_msg["type"] != "ready":
raise Exception(f"unexpected first message type {ready_msg['type']}")
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
"""Close websocket connection and reset state."""
try:
await self.stop_all_metrics()
if self._websocket:
await self._websocket.close()
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
def _get_websocket(self):
"""Get active websocket connection or raise exception."""
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
"""Flush any pending audio synthesis."""
if not self._websocket:
return
try:
msg = {"type": "end_of_stream"}
await self._websocket.send(json.dumps(msg))
except ConnectionClosedOK:
logger.debug(f"{self}: connection closed normally during flush")
except Exception as e:
logger.error(f"{self} exception: {e}")
async def _receive_messages(self):
"""Process incoming websocket messages."""
# TODO(laurent): This should not be necessary as it should happen when
# receiving the messages but this does not seem to always be the case
# and that may lead to a busy polling loop.
if self._websocket and self._websocket.state is State.CLOSED:
raise ConnectionClosedOK(None, None)
async for message in self._get_websocket():
msg = json.loads(message)
if msg["type"] == "audio":
# Process audio chunk
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["audio"]),
sample_rate=self.sample_rate,
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "text":
await self.add_word_timestamps([(msg["text"], msg["start_s"])])
elif msg["type"] == "end_of_stream":
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
elif msg["type"] == "error":
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(error_msg=f"Error: {msg['message']}")
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push frame and handle end-of-turn conditions.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Gradium's streaming API.
Args:
text: The text to convert to speech.
Yields:
Frame: Audio frames containing the synthesized speech.
"""
_state = self._websocket.state if self._websocket is not None else None
logger.debug(f"{self}: Generating TTS [{text}] {_state}")
try:
if not self._websocket or self._websocket.state is State.CLOSED:
self._websocket = None
await self._connect()
try:
yield TTSStartedFrame()
msg = self._build_msg(text=text)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -8,10 +8,14 @@ import base64
import os
from typing import Any, AsyncGenerator, Optional
import httpx
from loguru import logger
from pydantic import BaseModel
from pipecat import __version__
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
@@ -26,11 +30,7 @@ from pipecat.utils.tracing.service_decorators import traced_tts
try:
from hume import AsyncHumeClient
from hume.tts import (
FormatPcm,
PostedUtterance,
PostedUtteranceVoiceWithId,
)
from hume.tts import FormatPcm, PostedUtterance, PostedUtteranceVoiceWithId
from hume.tts.types import TimestampMessage
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
logger.error(f"Exception: {e}")
@@ -40,6 +40,12 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
# Tracking headers for Hume API requests
DEFAULT_HEADERS = {
"X-Hume-Client-Name": "pipecat",
"X-Hume-Client-Version": __version__,
}
class HumeTTSService(WordTTSService):
"""Hume Octave Text-to-Speech service.
@@ -104,7 +110,11 @@ class HumeTTSService(WordTTSService):
**kwargs,
)
self._client = AsyncHumeClient(api_key=api_key)
# Create a custom httpx.AsyncClient with tracking headers
# Headers are included in all requests made by the Hume SDK
self._http_client = httpx.AsyncClient(headers=DEFAULT_HEADERS)
self._client = AsyncHumeClient(api_key=api_key, httpx_client=self._http_client)
self._params = params or HumeTTSService.InputParams()
# Store voice in the base class (mirrors other services)
@@ -138,6 +148,26 @@ class HumeTTSService(WordTTSService):
self._cumulative_time = 0.0
self._started = False
async def stop(self, frame: EndFrame) -> None:
"""Stop the service and cleanup resources.
Args:
frame: The end frame.
"""
await super().stop(frame)
if hasattr(self, "_http_client") and self._http_client:
await self._http_client.aclose()
async def cancel(self, frame: CancelFrame) -> None:
"""Cancel the service and cleanup resources.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
if hasattr(self, "_http_client") and self._http_client:
await self._http_client.aclose()
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame and handle state changes.

View File

@@ -173,16 +173,17 @@ class LLMService(AIService):
run_in_parallel: Whether to run function calls in parallel or sequentially.
Defaults to True.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(**kwargs)
self._run_in_parallel = run_in_parallel
self._start_callbacks = {}
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
self._sequential_runner_task: Optional[asyncio.Task] = None
self._tracing_enabled: bool = False
self._skip_tts: bool = False
self._skip_tts: Optional[bool] = None
self._register_event_handler("on_function_calls_started")
self._register_event_handler("on_completion_timeout")
@@ -293,7 +294,8 @@ class LLMService(AIService):
direction: The direction of frame pushing.
"""
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
frame.skip_tts = self._skip_tts
if self._skip_tts is not None:
frame.skip_tts = self._skip_tts
await super().push_frame(frame, direction)
@@ -435,6 +437,7 @@ class LLMService(AIService):
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
runner_items = []
for function_call in function_calls:
if function_call.function_name in self._functions.keys():
item = self._functions[function_call.function_name]
@@ -446,28 +449,20 @@ class LLMService(AIService):
)
continue
runner_item = FunctionCallRunnerItem(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
runner_items.append(
FunctionCallRunnerItem(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
)
)
if self._run_in_parallel:
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
else:
await self._sequential_runner_queue.put(runner_item)
async def _call_start_function(
self, context: OpenAILLMContext | LLMContext, function_name: str
):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():
return await self._start_callbacks[None](function_name, self, context)
if self._run_in_parallel:
await self._run_parallel_function_calls(runner_items)
else:
await self._run_sequential_function_calls(runner_items)
async def request_image_frame(
self,
@@ -540,6 +535,27 @@ class LLMService(AIService):
await task
del self._function_call_tasks[task]
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
tasks = []
for runner_item in runner_items:
task = self.create_task(self._run_function_call(runner_item))
tasks.append(task)
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
# Enqueue all function calls for background execution.
for runner_item in runner_items:
await self._sequential_runner_queue.put(runner_item)
async def _call_start_function(
self, context: OpenAILLMContext | LLMContext, function_name: str
):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():
return await self._start_callbacks[None](function_name, self, context)
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
if runner_item.function_name in self._functions.keys():
item = self._functions[runner_item.function_name]
@@ -623,20 +639,19 @@ class LLMService(AIService):
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
# We remove the callback because we are going to cancel the task
# now, otherwise we will be removing it from the set while we
# are iterating.
task.remove_done_callback(self._function_call_task_finished)
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
await self.cancel_task(task)
if task:
# We remove the callback because we are going to cancel the
# task next, otherwise we will be removing it from the set
# while we are iterating.
task.remove_done_callback(self._function_call_task_finished)
await self.cancel_task(task)
cancelled_tasks.add(task)
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
cancelled_tasks.add(task)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
# Remove all cancelled tasks from our set.

View File

@@ -7,7 +7,7 @@
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
import json
from typing import Any, Dict, List, TypeAlias
from typing import Any, Callable, Dict, List, Optional, TypeAlias
from loguru import logger
@@ -46,17 +46,24 @@ class MCPClient(BaseObject):
def __init__(
self,
server_params: ServerParameters,
tools_filter: Optional[List[str]] = None,
tools_output_filters: Optional[Dict[str, Callable[[Any], Any]]] = None,
**kwargs,
):
"""Initialize the MCP client with server parameters.
Args:
server_params: Server connection parameters (stdio or SSE).
tools_filter: Optional list of tool names to register. If None, all tools are registered.
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
**kwargs: Additional arguments passed to the parent BaseObject.
"""
super().__init__(**kwargs)
self._server_params = server_params
self._session = ClientSession
self._tools_filter = tools_filter
self._tools_output_filters = tools_output_filters or {}
if isinstance(server_params, StdioServerParameters):
self._client = stdio_client
@@ -264,13 +271,26 @@ class MCPClient(BaseObject):
else:
# logger.debug(f"Non-text result content: '{content}'")
pass
logger.info(f"Tool '{function_name}' completed successfully")
logger.debug(f"Final response: {response}")
else:
logger.error(f"Error getting content from {function_name} results.")
final_response = response if len(response) else "Sorry, could not call the mcp tool"
await result_callback(final_response)
# Apply output filter if configured for this tool
if function_name in self._tools_output_filters:
try:
response = self._tools_output_filters[function_name](response)
logger.debug(f"Final response (after filter): {response}")
except Exception:
logger.error(f"Error applying output filter for {function_name}")
response = ""
if response and len(response) and isinstance(response, str):
logger.info(f"Tool '{function_name}' completed successfully")
logger.debug(f"Final response: {response}")
else:
response = "Sorry, could not call the mcp tool"
await result_callback(response)
async def _list_tools_helper(self, session):
available_tools = await session.list_tools()
@@ -283,6 +303,12 @@ class MCPClient(BaseObject):
for tool in available_tools.tools:
tool_name = tool.name
# Apply tools filter if configured
if self._tools_filter and tool_name not in self._tools_filter:
logger.debug(f"Skipping tool '{tool_name}' - not in allowed tools list")
continue
logger.debug(f"Processing tool: {tool_name}")
logger.debug(f"Tool description: {tool.description}")

View File

@@ -8,98 +8,23 @@
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
Microservice) API while maintaining compatibility with the OpenAI-style interface.
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaLLMService from
pipecat.services.nvidia.llm instead.
"""
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import OpenAILLMService
import warnings
from pipecat.services.nvidia.llm import NvidiaLLMService
class NimLLMService(OpenAILLMService):
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"NimLLMService from pipecat.services.nim.llm is deprecated. "
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
DeprecationWarning,
stacklevel=2,
)
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
compatibility with the OpenAI-style interface. It specifically handles the difference
in token usage reporting between NIM (incremental) and OpenAI (final summary).
"""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://integrate.api.nvidia.com/v1",
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
**kwargs,
):
"""Initialize the NimLLMService.
Args:
api_key: The API key for accessing NVIDIA's NIM API.
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
# Counters for accumulating token usage metrics
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = False
async def _process_context(self, context: OpenAILLMContext | LLMContext):
"""Process a context through the LLM and accumulate token usage metrics.
This method overrides the parent class implementation to handle NVIDIA's
incremental token reporting style, accumulating the counts and reporting
them once at the end of processing.
Args:
context: The context to process, containing messages and other information
needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = True
try:
await super()._process_context(context)
finally:
self._is_processing = False
# Report final accumulated token usage at the end of processing
if self._prompt_tokens > 0 or self._completion_tokens > 0:
self._total_tokens = self._prompt_tokens + self._completion_tokens
tokens = LLMTokenUsage(
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
)
await super().start_llm_usage_metrics(tokens)
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
"""Accumulate token usage metrics during processing.
This method intercepts the incremental token updates from NVIDIA's API
and accumulates them instead of passing each update to the metrics system.
The final accumulated totals are reported at the end of processing.
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
return
# Record prompt tokens the first time we see them
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
self._prompt_tokens = tokens.prompt_tokens
self._has_reported_prompt_tokens = True
# Update completion tokens count if it has increased
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens
NimLLMService = NvidiaLLMService

View File

View File

@@ -0,0 +1,105 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA NIM API service implementation.
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
Microservice) API while maintaining compatibility with the OpenAI-style interface.
"""
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import OpenAILLMService
class NvidiaLLMService(OpenAILLMService):
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
compatibility with the OpenAI-style interface. It specifically handles the difference
in token usage reporting between NIM (incremental) and OpenAI (final summary).
"""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://integrate.api.nvidia.com/v1",
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
**kwargs,
):
"""Initialize the NvidiaLLMService.
Args:
api_key: The API key for accessing NVIDIA's NIM API.
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
# Counters for accumulating token usage metrics
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = False
async def _process_context(self, context: OpenAILLMContext | LLMContext):
"""Process a context through the LLM and accumulate token usage metrics.
This method overrides the parent class implementation to handle NVIDIA's
incremental token reporting style, accumulating the counts and reporting
them once at the end of processing.
Args:
context: The context to process, containing messages and other information
needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = True
try:
await super()._process_context(context)
finally:
self._is_processing = False
# Report final accumulated token usage at the end of processing
if self._prompt_tokens > 0 or self._completion_tokens > 0:
self._total_tokens = self._prompt_tokens + self._completion_tokens
tokens = LLMTokenUsage(
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
)
await super().start_llm_usage_metrics(tokens)
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
"""Accumulate token usage metrics during processing.
This method intercepts the incremental token updates from NVIDIA's API
and accumulates them instead of passing each update to the metrics system.
The final accumulated totals are reported at the end of processing.
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
return
# Record prompt tokens the first time we see them
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
self._prompt_tokens = tokens.prompt_tokens
self._has_reported_prompt_tokens = True
# Update completion tokens count if it has increased
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens

View File

@@ -0,0 +1,663 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
import asyncio
from concurrent.futures import CancelledError as FuturesCancelledError
from typing import AsyncGenerator, List, Mapping, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import SegmentedSTTService, STTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
raise Exception(f"Missing module: {e}")
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
"""Maps Language enum to NVIDIA Riva ASR language codes.
Source:
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
Args:
language: Language enum value.
Returns:
Optional[str]: NVIDIA Riva language code or None if not supported.
"""
LANGUAGE_MAP = {
# Arabic
Language.AR: "ar-AR",
# English
Language.EN: "en-US", # Default to US
Language.EN_US: "en-US",
Language.EN_GB: "en-GB",
# French
Language.FR: "fr-FR",
Language.FR_FR: "fr-FR",
# German
Language.DE: "de-DE",
Language.DE_DE: "de-DE",
# Hindi
Language.HI: "hi-IN",
Language.HI_IN: "hi-IN",
# Italian
Language.IT: "it-IT",
Language.IT_IT: "it-IT",
# Japanese
Language.JA: "ja-JP",
Language.JA_JP: "ja-JP",
# Korean
Language.KO: "ko-KR",
Language.KO_KR: "ko-KR",
# Portuguese
Language.PT: "pt-BR", # Default to Brazilian
Language.PT_BR: "pt-BR",
# Russian
Language.RU: "ru-RU",
Language.RU_RU: "ru-RU",
# Spanish
Language.ES: "es-ES", # Default to Spain
Language.ES_ES: "es-ES",
Language.ES_US: "es-US", # US Spanish
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
class NvidiaSTTService(STTService):
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
through streaming recognition. Supports interim results and continuous audio
processing for low-latency applications.
"""
class InputParams(BaseModel):
"""Configuration parameters for NVIDIA Riva STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
"""
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva STT service.
Args:
api_key: NVIDIA API key for authentication.
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for NVIDIA Riva.
**kwargs: Additional arguments passed to STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaSTTService.InputParams()
self._api_key = api_key
self._profanity_filter = False
self._automatic_punctuation = True
self._no_verbatim_transcripts = False
self._language_code = params.language
self._boosted_lm_words = None
self._boosted_lm_score = 4.0
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self._function_id = model_function_map.get("function_id")
self._settings = {
"language": str(params.language),
"profanity_filter": self._profanity_filter,
"automatic_punctuation": self._automatic_punctuation,
"verbatim_transcripts": not self._no_verbatim_transcripts,
"boosted_lm_words": self._boosted_lm_words,
"boosted_lm_score": self._boosted_lm_score,
}
self.set_model_name(model_function_map.get("model_name"))
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._queue = None
self._config = None
self._thread_task = None
self._response_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
False - this service does not support metrics generation.
"""
return False
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
if self._config:
return
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
self._queue = asyncio.Queue()
if not self._thread_task:
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_task = self.create_task(self._response_task_handler())
async def stop(self, frame: EndFrame):
"""Stop the NVIDIA Riva STT service and clean up resources.
Args:
frame: EndFrame indicating pipeline stop.
"""
await super().stop(frame)
await self._stop_tasks()
async def cancel(self, frame: CancelFrame):
"""Cancel the NVIDIA Riva STT service operation.
Args:
frame: CancelFrame indicating operation cancellation.
"""
await super().cancel(frame)
await self._stop_tasks()
async def _stop_tasks(self):
if self._thread_task:
await self.cancel_task(self._thread_task)
self._thread_task = None
if self._response_task:
await self.cancel_task(self._response_task)
self._response_task = None
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
audio_chunks=self,
streaming_config=self._config,
)
for response in responses:
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
async def _thread_task_handler(self):
try:
self._thread_running = True
await asyncio.to_thread(self._response_handler)
except asyncio.CancelledError:
self._thread_running = False
raise
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def _handle_response(self, response):
for result in response.results:
if result and not result.alternatives:
continue
transcript = result.alternatives[0].transcript
if transcript and len(transcript) > 0:
await self.stop_ttfb_metrics()
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
await self._handle_transcription(
transcript=transcript,
is_final=result.is_final,
language=self._language_code,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
None - transcription results are pushed to the pipeline via frames.
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._queue.put(audio)
yield None
def __next__(self) -> bytes:
"""Get the next audio chunk for NVIDIA Riva processing.
Returns:
Audio bytes from the queue.
Raises:
StopIteration: When the thread is no longer running.
"""
if not self._thread_running:
raise StopIteration
try:
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
return future.result()
except FuturesCancelledError:
raise StopIteration
def __iter__(self):
"""Return iterator for audio chunk processing.
Returns:
Self as iterator.
"""
return self
class NvidiaSegmentedSTTService(SegmentedSTTService):
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
transcription on audio segments. It inherits from SegmentedSTTService to handle
audio buffering and speech detection.
"""
class InputParams(BaseModel):
"""Configuration parameters for NVIDIA Riva segmented STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
profanity_filter: Whether to filter profanity from results.
automatic_punctuation: Whether to add automatic punctuation.
verbatim_transcripts: Whether to return verbatim transcripts.
boosted_lm_words: List of words to boost in language model.
boosted_lm_score: Score boost for specified words.
"""
language: Optional[Language] = Language.EN_US
profanity_filter: bool = False
automatic_punctuation: bool = True
verbatim_transcripts: bool = False
boosted_lm_words: Optional[List[str]] = None
boosted_lm_score: float = 4.0
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
"model_name": "canary-1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva segmented STT service.
Args:
api_key: NVIDIA API key for authentication
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
params: Additional configuration parameters for NVIDIA Riva
**kwargs: Additional arguments passed to SegmentedSTTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaSegmentedSTTService.InputParams()
# Set model name
self.set_model_name(model_function_map.get("model_name"))
# Initialize NVIDIA Riva settings
self._api_key = api_key
self._server = server
self._function_id = model_function_map.get("function_id")
self._model_name = model_function_map.get("model_name")
# Store the language as a Language enum and as a string
self._language_enum = params.language or Language.EN_US
self._language = self.language_to_service_language(self._language_enum) or "en-US"
# Configure transcription parameters
self._profanity_filter = params.profanity_filter
self._automatic_punctuation = params.automatic_punctuation
self._verbatim_transcripts = params.verbatim_transcripts
self._boosted_lm_words = params.boosted_lm_words
self._boosted_lm_score = params.boosted_lm_score
# Voice activity detection thresholds (use NVIDIA Riva defaults)
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
# Create NVIDIA Riva client
self._config = None
self._asr_service = None
self._settings = {"language": self._language_enum}
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert pipecat Language enum to NVIDIA Riva's language code.
Args:
language: Language enum value.
Returns:
NVIDIA Riva language code or None if not supported.
"""
return language_to_nvidia_riva_language(language)
def _initialize_client(self):
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
if self._asr_service is not None:
return
# Set up authentication metadata for NVIDIA Cloud Functions
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
# Create authenticated client
auth = riva.client.Auth(None, True, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
def _create_recognition_config(self):
"""Create the NVIDIA Riva ASR recognition configuration."""
# Create base configuration
config = riva.client.RecognitionConfig(
language_code=self._language, # Now using the string, not a tuple
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=self._verbatim_transcripts,
)
# Add word boosting if specified
if self._boosted_lm_words:
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
# Add voice activity detection parameters
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
# Add any custom configuration
if self._custom_configuration:
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
return config
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True - this service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Initialize the service when the pipeline starts.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
async def set_language(self, language: Language):
"""Set the language for the STT service.
Args:
language: Target language for transcription.
"""
logger.info(f"Switching STT language to: [{language}]")
self._language_enum = language
self._language = self.language_to_service_language(language) or "en-US"
self._settings["language"] = language
# Update configuration with new language
if self._config:
self._config.language_code = self._language
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribe an audio segment.
Args:
audio: Raw audio bytes in WAV format (already converted by base class).
Yields:
Frame: TranscriptionFrame containing the transcribed text.
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Make sure the client is initialized
if self._asr_service is None:
self._initialize_client()
# Make sure the config is created
if self._config is None:
self._config = self._create_recognition_config()
# Type assertion to satisfy the IDE
assert self._asr_service is not None, "ASR service not initialized"
assert self._config is not None, "Recognition config not created"
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Process the response - handle different possible return types
try:
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# Process transcription results
transcription_found = False
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug("No transcription results found in NVIDIA Riva response")
except AttributeError as ae:
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -0,0 +1,187 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva text-to-speech service implementation.
This module provides integration with NVIDIA Riva's TTS services through
gRPC API for high-quality speech synthesis.
"""
import asyncio
import os
from typing import AsyncGenerator, Mapping, Optional
from pipecat.utils.tracing.service_decorators import traced_tts
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
raise Exception(f"Missing module: {e}")
NVIDIA_TTS_TIMEOUT_SECS = 5
class NvidiaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
cloud-based TTS models. Supports multiple voices, languages, and
configurable quality settings.
"""
class InputParams(BaseModel):
"""Input parameters for Riva TTS configuration.
Parameters:
language: Language code for synthesis. Defaults to US English.
quality: Audio quality setting (0-100). Defaults to 20.
"""
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
"model_name": "magpie-tts-multilingual",
},
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for the TTS model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaTTSService.InputParams()
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
self._quality = params.quality
self._function_id = model_function_map.get("function_id")
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
async def set_model(self, model: str):
"""Attempt to set the TTS model.
Note: Model cannot be changed after initialization for Riva service.
Args:
model: The model name to set (operation not supported).
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech data.
"""
def read_audio_responses(queue: asyncio.Queue):
def add_response(r):
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
try:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
for r in responses:
add_response(r)
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
add_response(None)
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()

View File

@@ -133,6 +133,7 @@ class BaseOpenAILLMService(LLMService):
self._retry_timeout_secs = retry_timeout_secs
self._retry_on_timeout = retry_on_timeout
self.set_model_name(model)
self._full_model_name: str = ""
self._client = self.create_client(
api_key=api_key,
base_url=base_url,
@@ -185,6 +186,22 @@ class BaseOpenAILLMService(LLMService):
"""
return True
def set_full_model_name(self, full_model_name: str):
"""Set the full AI model name.
Args:
full_model_name: The full name of the AI model to use.
"""
self._full_model_name = full_model_name
def get_full_model_name(self):
"""Get the current full model name.
Returns:
The full name of the AI model being used.
"""
return self._full_model_name
async def get_chat_completions(
self, params_from_context: OpenAILLMInvocationParams
) -> AsyncStream[ChatCompletionChunk]:
@@ -360,6 +377,9 @@ class BaseOpenAILLMService(LLMService):
)
await self.start_llm_usage_metrics(tokens)
if chunk.model and self.get_full_model_name() != chunk.model:
self.set_full_model_name(chunk.model)
if chunk.choices is None or len(chunk.choices) == 0:
continue

View File

@@ -4,707 +4,32 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription.
import asyncio
from concurrent.futures import CancelledError as FuturesCancelledError
from typing import AsyncGenerator, List, Mapping, Optional
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaSTTService from
pipecat.services.nvidia.stt instead.
"""
from loguru import logger
from pydantic import BaseModel
import warnings
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
from pipecat.services.nvidia.stt import (
NvidiaSegmentedSTTService,
NvidiaSTTService,
language_to_nvidia_riva_language,
)
from pipecat.services.stt_service import SegmentedSTTService, STTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
def language_to_riva_language(language: Language) -> Optional[str]:
"""Maps Language enum to Riva ASR language codes.
Source:
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
Args:
language: Language enum value.
Returns:
Optional[str]: Riva language code or None if not supported.
"""
LANGUAGE_MAP = {
# Arabic
Language.AR: "ar-AR",
# English
Language.EN: "en-US", # Default to US
Language.EN_US: "en-US",
Language.EN_GB: "en-GB",
# French
Language.FR: "fr-FR",
Language.FR_FR: "fr-FR",
# German
Language.DE: "de-DE",
Language.DE_DE: "de-DE",
# Hindi
Language.HI: "hi-IN",
Language.HI_IN: "hi-IN",
# Italian
Language.IT: "it-IT",
Language.IT_IT: "it-IT",
# Japanese
Language.JA: "ja-JP",
Language.JA_JP: "ja-JP",
# Korean
Language.KO: "ko-KR",
Language.KO_KR: "ko-KR",
# Portuguese
Language.PT: "pt-BR", # Default to Brazilian
Language.PT_BR: "pt-BR",
# Russian
Language.RU: "ru-RU",
Language.RU_RU: "ru-RU",
# Spanish
Language.ES: "es-ES", # Default to Spain
Language.ES_ES: "es-ES",
Language.ES_US: "es-US", # US Spanish
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
class RivaSTTService(STTService):
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
through streaming recognition. Supports interim results and continuous audio
processing for low-latency applications.
"""
class InputParams(BaseModel):
"""Configuration parameters for Riva STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
"""
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Riva STT service.
Args:
api_key: NVIDIA API key for authentication.
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for Riva.
**kwargs: Additional arguments passed to STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaSTTService.InputParams()
self._api_key = api_key
self._profanity_filter = False
self._automatic_punctuation = True
self._no_verbatim_transcripts = False
self._language_code = params.language
self._boosted_lm_words = None
self._boosted_lm_score = 4.0
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self._function_id = model_function_map.get("function_id")
self._settings = {
"language": str(params.language),
"profanity_filter": self._profanity_filter,
"automatic_punctuation": self._automatic_punctuation,
"verbatim_transcripts": not self._no_verbatim_transcripts,
"boosted_lm_words": self._boosted_lm_words,
"boosted_lm_score": self._boosted_lm_score,
}
self.set_model_name(model_function_map.get("model_name"))
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._queue = None
self._config = None
self._thread_task = None
self._response_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
False - this service does not support metrics generation.
"""
return False
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Start the Riva STT service and initialize streaming configuration.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
if self._config:
return
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
self._queue = asyncio.Queue()
if not self._thread_task:
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_task = self.create_task(self._response_task_handler())
async def stop(self, frame: EndFrame):
"""Stop the Riva STT service and clean up resources.
Args:
frame: EndFrame indicating pipeline stop.
"""
await super().stop(frame)
await self._stop_tasks()
async def cancel(self, frame: CancelFrame):
"""Cancel the Riva STT service operation.
Args:
frame: CancelFrame indicating operation cancellation.
"""
await super().cancel(frame)
await self._stop_tasks()
async def _stop_tasks(self):
if self._thread_task:
await self.cancel_task(self._thread_task)
self._thread_task = None
if self._response_task:
await self.cancel_task(self._response_task)
self._response_task = None
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
audio_chunks=self,
streaming_config=self._config,
)
for response in responses:
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
async def _thread_task_handler(self):
try:
self._thread_running = True
await asyncio.to_thread(self._response_handler)
except asyncio.CancelledError:
self._thread_running = False
raise
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def _handle_response(self, response):
for result in response.results:
if result and not result.alternatives:
continue
transcript = result.alternatives[0].transcript
if transcript and len(transcript) > 0:
await self.stop_ttfb_metrics()
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
await self._handle_transcription(
transcript=transcript,
is_final=result.is_final,
language=self._language_code,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
None - transcription results are pushed to the pipeline via frames.
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._queue.put(audio)
yield None
def __next__(self) -> bytes:
"""Get the next audio chunk for Riva processing.
Returns:
Audio bytes from the queue.
Raises:
StopIteration: When the thread is no longer running.
"""
if not self._thread_running:
raise StopIteration
try:
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
return future.result()
except FuturesCancelledError:
raise StopIteration
def __iter__(self):
"""Return iterator for audio chunk processing.
Returns:
Self as iterator.
"""
return self
class RivaSegmentedSTTService(SegmentedSTTService):
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
transcription on audio segments. It inherits from SegmentedSTTService to handle
audio buffering and speech detection.
"""
class InputParams(BaseModel):
"""Configuration parameters for Riva segmented STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
profanity_filter: Whether to filter profanity from results.
automatic_punctuation: Whether to add automatic punctuation.
verbatim_transcripts: Whether to return verbatim transcripts.
boosted_lm_words: List of words to boost in language model.
boosted_lm_score: Score boost for specified words.
"""
language: Optional[Language] = Language.EN_US
profanity_filter: bool = False
automatic_punctuation: bool = True
verbatim_transcripts: bool = False
boosted_lm_words: Optional[List[str]] = None
boosted_lm_score: float = 4.0
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
"model_name": "canary-1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Riva segmented STT service.
Args:
api_key: NVIDIA API key for authentication
server: Riva server address (defaults to NVIDIA Cloud Function endpoint)
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
params: Additional configuration parameters for Riva
**kwargs: Additional arguments passed to SegmentedSTTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaSegmentedSTTService.InputParams()
# Set model name
self.set_model_name(model_function_map.get("model_name"))
# Initialize Riva settings
self._api_key = api_key
self._server = server
self._function_id = model_function_map.get("function_id")
self._model_name = model_function_map.get("model_name")
# Store the language as a Language enum and as a string
self._language_enum = params.language or Language.EN_US
self._language = self.language_to_service_language(self._language_enum) or "en-US"
# Configure transcription parameters
self._profanity_filter = params.profanity_filter
self._automatic_punctuation = params.automatic_punctuation
self._verbatim_transcripts = params.verbatim_transcripts
self._boosted_lm_words = params.boosted_lm_words
self._boosted_lm_score = params.boosted_lm_score
# Voice activity detection thresholds (use Riva defaults)
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
# Create Riva client
self._config = None
self._asr_service = None
self._settings = {"language": self._language_enum}
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert pipecat Language enum to Riva's language code.
Args:
language: Language enum value.
Returns:
Riva language code or None if not supported.
"""
return language_to_riva_language(language)
def _initialize_client(self):
"""Initialize the Riva ASR client with authentication metadata."""
if self._asr_service is not None:
return
# Set up authentication metadata for NVIDIA Cloud Functions
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
# Create authenticated client
auth = riva.client.Auth(None, True, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized RivaSegmentedSTTService with model: {self.model_name}")
def _create_recognition_config(self):
"""Create the Riva ASR recognition configuration."""
# Create base configuration
config = riva.client.RecognitionConfig(
language_code=self._language, # Now using the string, not a tuple
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=self._verbatim_transcripts,
)
# Add word boosting if specified
if self._boosted_lm_words:
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
# Add voice activity detection parameters
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
# Add any custom configuration
if self._custom_configuration:
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
return config
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True - this service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Initialize the service when the pipeline starts.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
async def set_language(self, language: Language):
"""Set the language for the STT service.
Args:
language: Target language for transcription.
"""
logger.info(f"Switching STT language to: [{language}]")
self._language_enum = language
self._language = self.language_to_service_language(language) or "en-US"
self._settings["language"] = language
# Update configuration with new language
if self._config:
self._config.language_code = self._language
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribe an audio segment.
Args:
audio: Raw audio bytes in WAV format (already converted by base class).
Yields:
Frame: TranscriptionFrame containing the transcribed text.
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Make sure the client is initialized
if self._asr_service is None:
self._initialize_client()
# Make sure the config is created
if self._config is None:
self._config = self._create_recognition_config()
# Type assertion to satisfy the IDE
assert self._asr_service is not None, "ASR service not initialized"
assert self._config is not None, "Recognition config not created"
# Process audio with Riva ASR - explicitly request non-future response
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Process the response - handle different possible return types
try:
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# Process transcription results
transcription_found = False
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug("No transcription results found in Riva response")
except AttributeError as ae:
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class ParakeetSTTService(RivaSTTService):
"""Deprecated speech-to-text service using NVIDIA Parakeet models.
.. deprecated:: 0.0.66
This class is deprecated. Use `RivaSTTService` instead for equivalent functionality
with Parakeet models by specifying the appropriate model_function_map.
"""
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[RivaSTTService.InputParams] = None, # Use parent class's type
**kwargs,
):
"""Initialize the Parakeet STT service.
Args:
api_key: NVIDIA API key for authentication.
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for Riva.
**kwargs: Additional arguments passed to RivaSTTService.
"""
super().__init__(
api_key=api_key,
server=server,
model_function_map=model_function_map,
sample_rate=sample_rate,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
DeprecationWarning,
)
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"RivaSTTService and ParakeetSTTService "
"from pipecat.services.riva.stt is deprecated. "
"Please use NvidiaSTTService from pipecat.services.nvidia.stt instead.",
DeprecationWarning,
stacklevel=2,
)
RivaSTTService = NvidiaSTTService
language_to_riva_language = language_to_nvidia_riva_language
RivaSegmentedSTTService = NvidiaSegmentedSTTService
ParakeetSTTService = NvidiaSTTService

View File

@@ -8,231 +8,26 @@
This module provides integration with NVIDIA Riva's TTS services through
gRPC API for high-quality speech synthesis.
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaTTSService from
pipecat.services.nvidia.tts instead.
"""
import asyncio
import os
from typing import AsyncGenerator, Mapping, Optional
import warnings
from pipecat.utils.tracing.service_decorators import traced_tts
from pipecat.services.nvidia.tts import NVIDIA_TTS_TIMEOUT_SECS, NvidiaTTSService
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"FastPitchTTSService and RivaTTSService "
"from pipecat.services.nim.llm are deprecated. "
"Please use NvidiaLLMService from pipecat.services.nvidia.tts instead.",
DeprecationWarning,
stacklevel=2,
)
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
RIVA_TTS_TIMEOUT_SECS = 5
class RivaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
cloud-based TTS models. Supports multiple voices, languages, and
configurable quality settings.
"""
class InputParams(BaseModel):
"""Input parameters for Riva TTS configuration.
Parameters:
language: Language code for synthesis. Defaults to US English.
quality: Audio quality setting (0-100). Defaults to 20.
"""
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
"model_name": "magpie-tts-multilingual",
},
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for the TTS model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaTTSService.InputParams()
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
self._quality = params.quality
self._function_id = model_function_map.get("function_id")
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
async def set_model(self, model: str):
"""Attempt to set the TTS model.
Note: Model cannot be changed after initialization for Riva service.
Args:
model: The model name to set (operation not supported).
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech data.
"""
def read_audio_responses(queue: asyncio.Queue):
def add_response(r):
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
try:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
for r in responses:
add_response(r)
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
add_response(None)
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
class FastPitchTTSService(RivaTTSService):
"""Deprecated FastPitch TTS service.
.. deprecated:: 0.0.66
This class is deprecated. Use RivaTTSService instead for new implementations.
Provides backward compatibility for existing FastPitch TTS integrations.
"""
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "English-US.Female-1",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972",
"model_name": "fastpitch-hifigan-tts",
},
params: Optional[RivaTTSService.InputParams] = None,
**kwargs,
):
"""Initialize the deprecated FastPitch TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to Female-1 voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for FastPitch model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent RivaTTSService.
"""
super().__init__(
api_key=api_key,
server=server,
voice_id=voice_id,
sample_rate=sample_rate,
model_function_map=model_function_map,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
DeprecationWarning,
)
RivaTTSService = NvidiaTTSService
FastPitchTTSService = NvidiaTTSService
RIVA_TTS_TIMEOUT_SECS = NVIDIA_TTS_TIMEOUT_SECS

View File

@@ -514,9 +514,11 @@ class SarvamTTSService(InterruptibleTTSService):
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame and flush audio if it's the end of a full response."""
if isinstance(frame, LLMFullResponseEndFrame):
await super().process_frame(frame, direction)
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
return await super().process_frame(frame, direction)
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""

View File

@@ -425,16 +425,13 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
pending_aggregation = self._text_aggregator.text
# Flush any remaining text (including text waiting for lookahead)
remaining = await self._text_aggregator.flush()
if remaining:
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
# Reset aggregator state
await self._text_aggregator.reset()
self._processing_text = False
if pending_aggregation.text:
await self._push_tts_frames(
AggregatedTextFrame(pending_aggregation.text, pending_aggregation.type)
)
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -539,17 +536,20 @@ class TTSService(AIService):
text = frame.text
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
aggregated_by = "token"
if text:
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
else:
aggregate = await self._text_aggregator.aggregate(frame.text)
if aggregate:
async for aggregate in self._text_aggregator.aggregate(frame.text):
text = aggregate.text
aggregated_by = aggregate.type
if text:
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
async def _push_tts_frames(
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False

View File

@@ -12,7 +12,7 @@ from typing import Awaitable, Callable, Optional
import websockets
from loguru import logger
from websockets.exceptions import ConnectionClosedOK
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from websockets.protocol import State
from pipecat.frames.frames import ErrorFrame
@@ -137,6 +137,10 @@ class WebsocketService(ABC):
# Normal closure, don't retry
logger.debug(f"{self} connection closed normally: {e}")
break
except ConnectionClosedError as e:
# Error closure, don't retry
logger.warning(f"{self} connection closed, but with an error: {e}")
break
except Exception as e:
message = f"{self} error receiving messages: {e}"
logger.error(message)

View File

@@ -6,31 +6,14 @@
"""Base notifier interface for Pipecat."""
from abc import ABC, abstractmethod
import warnings
from pipecat.utils.sync.base_notifier import BaseNotifier
class BaseNotifier(ABC):
"""Abstract base class for notification mechanisms.
Provides a standard interface for implementing notification and waiting
patterns used for event coordination and signaling between components
in the Pipecat framework.
"""
@abstractmethod
async def notify(self):
"""Send a notification signal.
Implementations should trigger any waiting coroutines or processes
that are blocked on this notifier.
"""
pass
@abstractmethod
async def wait(self):
"""Wait for a notification signal.
Implementations should block until a notification is received
from the corresponding notify() call.
"""
pass
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -6,40 +6,14 @@
"""Event-based notifier implementation using asyncio Event primitives."""
import asyncio
import warnings
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
class EventNotifier(BaseNotifier):
"""Event-based notifier using asyncio.Event for task synchronization.
Provides a simple notification mechanism where one task can signal
an event and other tasks can wait for that event to occur. The event
is automatically cleared after each wait operation.
"""
def __init__(self):
"""Initialize the event notifier.
Creates an internal asyncio.Event for managing notifications.
"""
self._event = asyncio.Event()
async def notify(self):
"""Signal the event to notify waiting tasks.
Sets the internal event, causing any tasks waiting on this
notifier to be awakened.
"""
self._event.set()
async def wait(self):
"""Wait for the event to be signaled.
Blocks until another task calls notify(). Automatically clears
the event after being awakened so subsequent calls will wait
for the next notification.
"""
await self._event.wait()
self._event.clear()
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -12,6 +12,7 @@ comprehensive monitoring and cleanup capabilities.
"""
import asyncio
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Coroutine, Dict, Optional, Sequence
@@ -162,7 +163,9 @@ class TaskManager(BaseTaskManager):
# Re-raise the exception to ensure the task is cancelled.
raise
except Exception as e:
logger.error(f"{name}: unexpected exception: {e}")
tb = traceback.extract_tb(e.__traceback__)
last = tb[-1]
logger.error(f"{name} unexpected exception ({last.filename}:{last.lineno}): {e}")
if not self._params:
raise Exception("TaskManager is not setup: unable to get event loop")
@@ -197,9 +200,17 @@ class TaskManager(BaseTaskManager):
# Here are sure the task is cancelled properly.
pass
except Exception as e:
logger.error(f"{name}: unexpected exception while cancelling task: {e}")
tb = traceback.extract_tb(e.__traceback__)
last = tb[-1]
logger.error(
f"{name} unexpected exception while cancelling task ({last.filename}:{last.lineno}): {e}"
)
except BaseException as e:
logger.critical(f"{name}: fatal base exception while cancelling task: {e}")
tb = traceback.extract_tb(e.__traceback__)
last = tb[-1]
logger.critical(
f"{name} fatal base exception while cancelling task ({last.filename}:{last.lineno}): {e}"
)
raise
def current_tasks(self) -> Sequence[asyncio.Task]:

View File

@@ -203,7 +203,7 @@ def parse_start_end_tags(
class TextPartForConcatenation:
"""Class representing a part of text for concatenation with concatenate_aggregated_text.
Attributes:
Parameters:
text: The text content.
includes_inter_part_spaces: Whether any necessary inter-frame
(leading/trailing) spaces are already included in the text.

View File

View File

@@ -0,0 +1,36 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base notifier interface for Pipecat."""
from abc import ABC, abstractmethod
class BaseNotifier(ABC):
"""Abstract base class for notification mechanisms.
Provides a standard interface for implementing notification and waiting
patterns used for event coordination and signaling between components
in the Pipecat framework.
"""
@abstractmethod
async def notify(self):
"""Send a notification signal.
Implementations should trigger any waiting coroutines or processes
that are blocked on this notifier.
"""
pass
@abstractmethod
async def wait(self):
"""Wait for a notification signal.
Implementations should block until a notification is received
from the corresponding notify() call.
"""
pass

View File

@@ -0,0 +1,45 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Event-based notifier implementation using asyncio Event primitives."""
import asyncio
from pipecat.utils.sync.base_notifier import BaseNotifier
class EventNotifier(BaseNotifier):
"""Event-based notifier using asyncio.Event for task synchronization.
Provides a simple notification mechanism where one task can signal
an event and other tasks can wait for that event to occur. The event
is automatically cleared after each wait operation.
"""
def __init__(self):
"""Initialize the event notifier.
Creates an internal asyncio.Event for managing notifications.
"""
self._event = asyncio.Event()
async def notify(self):
"""Signal the event to notify waiting tasks.
Sets the internal event, causing any tasks waiting on this
notifier to be awakened.
"""
self._event.set()
async def wait(self):
"""Wait for the event to be signaled.
Blocks until another task calls notify(). Automatically clears
the event after being awakened so subsequent calls will wait
for the next notification.
"""
await self._event.wait()
self._event.clear()

View File

@@ -14,7 +14,7 @@ aggregated text should be sent for speech synthesis.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import AsyncIterator, Optional
class AggregationType(str, Enum):
@@ -80,33 +80,43 @@ class BaseTextAggregator(ABC):
pass
@abstractmethod
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate the specified text with the currently accumulated text.
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
"""Aggregate the specified text and yield completed aggregations.
This method should be implemented to define how the new text contributes
to the aggregation process. It returns the aggregated text and a string
describing how it was aggregated if it's ready to be processed,
or None otherwise.
This method processes the input text character-by-character internally
and yields Aggregation objects as they complete.
Subclasses should implement their specific logic for:
- How to combine new text with existing accumulated text
- How to process text character-by-character
- When to consider the aggregated text ready for processing
- What criteria determine text completion (e.g., sentence boundaries)
- When a completion occurs, the method should return an Aggregation object
containing the aggregated text and its type. The text should be stripped
of leading/trailing whitespace so that consumers can rely on a consistent
format.
- When a completion occurs, yield an Aggregation object containing the
aggregated text (stripped of leading/trailing whitespace) and its type
Args:
text: The text to be aggregated.
Yields:
Aggregation objects as they complete. Each Aggregation consists of
the aggregated text (stripped of leading/trailing whitespace) and
a string indicating the type of aggregation (e.g., 'sentence', 'word',
'token', 'my_custom_aggregation').
"""
pass
# Make this a generator to satisfy type checker
yield # pragma: no cover
@abstractmethod
async def flush(self) -> Optional[Aggregation]:
"""Flush any pending aggregation.
This method is called at the end of a stream (e.g., when receiving
LLMFullResponseEndFrame) to return any text that was buffered.
Returns:
An Aggregation object if ready for processing, or None if more
text is needed before the aggregated content is ready. If an Aggregation
object is returned, it should consist of the updated aggregated text,
stripped of leading/trailing whitespace, and a string indicating the
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
An Aggregation object if there is pending text, or None if there
is no pending text.
"""
pass

View File

@@ -13,12 +13,12 @@ support for custom handlers and configurable actions for when a pattern is found
import re
from enum import Enum
from typing import Awaitable, Callable, List, Optional, Tuple
from typing import AsyncIterator, Awaitable, Callable, List, Optional, Tuple
from loguru import logger
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
class MatchAction(Enum):
@@ -26,15 +26,15 @@ class MatchAction(Enum):
Parameters:
REMOVE: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
Sentence aggregation will continue on as if this text did not exist.
KEEP: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
Sentence aggregation will continue on with the internal text included.
AGGREGATE: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
"""
REMOVE = "remove"
@@ -72,7 +72,7 @@ class PatternMatch(Aggregation):
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
class PatternPairAggregator(BaseTextAggregator):
class PatternPairAggregator(SimpleTextAggregator):
"""Aggregator that identifies and processes content between pattern pairs.
This aggregator buffers text until it can identify complete pattern pairs
@@ -97,9 +97,10 @@ class PatternPairAggregator(BaseTextAggregator):
Creates an empty aggregator with no patterns or handlers registered.
Text buffering and pattern detection will begin when text is aggregated.
"""
self._text = ""
super().__init__()
self._patterns = {}
self._handlers = {}
self._last_processed_position = 0 # Track where we last checked for complete patterns
@property
def text(self) -> Aggregation:
@@ -132,17 +133,15 @@ class PatternPairAggregator(BaseTextAggregator):
Args:
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
those are reserved for the default behavior.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
those are reserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate
aggregation object.
action: What to do when a complete pattern is matched.
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as normal text. This allows you to register handlers for the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate aggregation object.
Returns:
Self for method chaining.
@@ -218,14 +217,18 @@ class PatternPairAggregator(BaseTextAggregator):
self._handlers[type] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
"""Process all complete pattern pairs in the text.
async def _process_complete_patterns(
self, text: str, last_processed_position: int = 0
) -> Tuple[List[PatternMatch], str]:
"""Process newly complete pattern pairs in the text.
Searches for all complete pattern pairs in the text, calls the
appropriate handlers, and optionally removes the matches.
Searches for pattern pairs that have been completed since last_processed_position,
calls the appropriate handlers, and optionally removes the matches.
Args:
text: The text to process for pattern matches.
last_processed_position: The position in text that was already processed.
Only patterns that end at or after this position will be processed.
Returns:
Tuple of (all_matches, processed_text) where:
@@ -259,17 +262,23 @@ class PatternPairAggregator(BaseTextAggregator):
content=content.strip(), type=type, full_match=full_match
)
# Call the appropriate handler if registered
if type in self._handlers:
# Check if this pattern was already processed
already_processed = match.end() <= last_processed_position
# Only call handler for newly completed patterns
if not already_processed and type in self._handlers:
try:
await self._handlers[type](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {type}: {e}")
# Remove the pattern from the text if configured
# Handle pattern based on action
if action == MatchAction.REMOVE:
processed_text = processed_text.replace(full_match, "", 1)
# Remove patterns are only removed once (when newly completed)
if not already_processed:
processed_text = processed_text.replace(full_match, "", 1)
else:
# KEEP/AGGREGATE patterns stay in all_matches
all_matches.append(pattern_match)
return all_matches, processed_text
@@ -305,76 +314,84 @@ class PatternPairAggregator(BaseTextAggregator):
return None
async def aggregate(self, text: str) -> Optional[PatternMatch]:
async def aggregate(self, text: str) -> AsyncIterator[PatternMatch]:
"""Aggregate text and process pattern pairs.
This method adds the new text to the buffer, processes any complete pattern
pairs, and returns processed text up to sentence boundaries if possible.
If there are incomplete patterns (start without matching end), it will
continue buffering text.
Processes the input text character-by-character, handles pattern pairs,
and uses the parent's lookahead logic for sentence detection when no
patterns are active.
Args:
text: New text to add to the buffer.
text: Text to aggregate.
Returns:
Processed text up to a sentence boundary, or None if more
text is needed to form a complete sentence or pattern.
Yields:
PatternMatch objects as patterns complete or sentences are detected.
"""
# Add new text to buffer
self._text += text
# Process text character by character
for char in text:
self._text += char
# Process any complete patterns in the buffer
patterns, processed_text = await self._process_complete_patterns(self._text)
# Process any newly complete patterns in the buffer
# Only patterns that complete after _last_processed_position will trigger handlers
patterns, processed_text = await self._process_complete_patterns(
self._text, self._last_processed_position
)
self._text = processed_text
# Update the last processed position to prevent re-processing patterns
# This tracks where in the buffer we've already called handlers, so we
# only trigger handlers once when a pattern completes
self._last_processed_position = len(self._text)
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
self._text = processed_text
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
yield patterns[0]
continue
# Check if we have incomplete patterns
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, continue
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
continue
# For AGGREGATE patterns: yield any text before the pattern starts
# This ensures text doesn't get stuck in the buffer waiting for sentence
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
yield PatternMatch(
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
continue
# Check if we have incomplete patterns
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
)
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
if eos_marker:
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return PatternMatch(
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
)
# No complete sentence found yet
return None
# Use parent's lookahead logic for sentence detection
aggregation = await super()._check_sentence_with_lookahead(char)
if aggregation:
# Convert to PatternMatch for consistency with return type
yield PatternMatch(
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
)
async def handle_interruption(self):
"""Handle interruptions by clearing the buffer.
"""Handle interruptions by clearing the buffer and pattern state.
Called when an interruption occurs in the processing pipeline,
to reset the state and discard any partially aggregated text.
"""
self._text = ""
await super().handle_interruption()
self._last_processed_position = 0
# Pattern and handler state persists across interruptions
async def reset(self):
"""Clear the internally aggregated text.
@@ -382,4 +399,6 @@ class PatternPairAggregator(BaseTextAggregator):
Resets the aggregator to its initial state, discarding any
buffered text and clearing pattern tracking state.
"""
self._text = ""
await super().reset()
self._last_processed_position = 0
# Pattern and handler state persists across resets

View File

@@ -11,9 +11,9 @@ until it finds an end-of-sentence marker, making it suitable for basic TTS
text processing scenarios.
"""
from typing import Optional
from typing import AsyncIterator, Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.string import SENTENCE_ENDING_PUNCTUATION, match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
@@ -31,6 +31,7 @@ class SimpleTextAggregator(BaseTextAggregator):
Creates an empty text buffer ready to begin accumulating text tokens.
"""
self._text = ""
self._needs_lookahead: bool = False
@property
def text(self) -> Aggregation:
@@ -41,30 +42,87 @@ class SimpleTextAggregator(BaseTextAggregator):
"""
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text and return completed sentences.
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
"""Aggregate text and yield completed sentences.
Adds the new text to the buffer and checks for end-of-sentence markers.
When a sentence boundary is found, returns the completed sentence and
removes it from the buffer.
Processes the input text character-by-character. When sentence-ending
punctuation is detected, it waits for non-whitespace lookahead before
calling NLTK. This prevents false positives like "$29." being detected
as a sentence when it's actually "$29.95".
Args:
text: New text to add to the aggregation buffer.
text: Text to aggregate.
Yields:
Complete sentences as Aggregation objects.
"""
# Process text character by character
for char in text:
self._text += char
# Check for sentence with lookahead
result = await self._check_sentence_with_lookahead(char)
if result:
yield result
async def _check_sentence_with_lookahead(self, char: str) -> Optional[Aggregation]:
"""Check for sentence boundaries using lookahead logic.
This method implements the core sentence detection logic with lookahead.
When sentence-ending punctuation is detected, it waits for the next
non-whitespace character before calling NLTK. This disambiguates cases
like "$29." (not a sentence) vs "$29. Next" (sentence ends at period).
Whitespace alone is not meaningful lookahead since it appears in both
cases. Instead, the first non-whitespace character after the punctuation
is used to confirm the sentence boundary.
Subclasses can call this via super() to reuse the lookahead behavior
while adding their own logic (e.g., tag handling, pattern matching).
Args:
char: The most recently added character (used for lookahead check).
Returns:
A complete sentence if an end-of-sentence marker is found,
or None if more text is needed to complete a sentence.
Aggregation if sentence found, None otherwise.
"""
result: Optional[str] = None
# If we need lookahead, check if we now have non-whitespace
if self._needs_lookahead:
# Check if the new character is non-whitespace
if char.strip():
# We have meaningful lookahead, call NLTK
self._needs_lookahead = False
eos_marker = match_endofsentence(self._text)
self._text += text
if eos_marker:
# NLTK confirmed a sentence - return it
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return Aggregation(text=result, type=AggregationType.SENTENCE)
# No sentence found - keep accumulating
return None
# Still whitespace, keep waiting
return None
eos_end_marker = match_endofsentence(self._text)
if eos_end_marker:
result = self._text[:eos_end_marker]
self._text = self._text[eos_end_marker:]
# Check if we just added sentence-ending punctuation
if self._text and self._text[-1] in SENTENCE_ENDING_PUNCTUATION:
# Mark that we need lookahead (don't call NLTK yet)
self._needs_lookahead = True
if result:
return None
async def flush(self) -> Optional[Aggregation]:
"""Flush any remaining text in the buffer.
Returns any text remaining in the buffer. This is called at the end
of a stream to ensure all text is processed.
Returns:
Any remaining text as a sentence, or None if buffer is empty.
"""
if self._text:
# Return whatever we have in the buffer
result = self._text
await self.reset()
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
return None
@@ -75,6 +133,7 @@ class SimpleTextAggregator(BaseTextAggregator):
discarding any partially accumulated text.
"""
self._text = ""
self._needs_lookahead = False
async def reset(self):
"""Clear the internally aggregated text.
@@ -83,3 +142,4 @@ class SimpleTextAggregator(BaseTextAggregator):
any accumulated text content.
"""
self._text = ""
self._needs_lookahead = False

View File

@@ -11,13 +11,14 @@ between specified start/end tag pairs, ensuring that tagged content is processed
as a unit regardless of internal punctuation.
"""
from typing import Optional, Sequence
from typing import AsyncIterator, Optional, Sequence
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
from pipecat.utils.string import StartEndTags, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
class SkipTagsAggregator(BaseTextAggregator):
class SkipTagsAggregator(SimpleTextAggregator):
"""Aggregator that prevents end of sentence matching between start/end tags.
This aggregator buffers text until it finds an end of sentence or a start
@@ -37,67 +38,59 @@ class SkipTagsAggregator(BaseTextAggregator):
tags: Sequence of StartEndTags objects defining the tag pairs
that should prevent sentence boundary detection.
"""
self._text = ""
super().__init__()
self._tags = tags
self._current_tag: Optional[StartEndTags] = None
self._current_tag_index: int = 0
@property
def text(self) -> Aggregation:
"""Get the currently buffered text.
Returns:
The current text buffer content that hasn't been processed yet.
"""
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[Aggregation]:
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
"""Aggregate text while respecting tag boundaries.
This method adds the new text to the buffer, processes any complete
pattern pairs, and returns processed text up to sentence boundaries if
possible. If there are incomplete patterns (start without matching
end), it will continue buffering text.
Processes the input text character-by-character, updates tag state, and
uses the parent's lookahead logic for sentence detection when not
inside tags.
Args:
text: New text to add to the buffer.
text: Text to aggregate.
Returns:
An Aggregation object containing text up to a sentence boundary and
marked as SENTENCE type or None if more text is needed to complete a
sentence or close tags.
Yields:
Aggregation objects containing text up to a sentence boundary,
marked as SENTENCE type.
"""
# Add new text to buffer
self._text += text
# Process text character by character
for char in text:
self._text += char
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
self._text, self._tags, self._current_tag, self._current_tag_index
)
# Update tag state
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
self._text, self._tags, self._current_tag, self._current_tag_index
)
# Find sentence boundary if no incomplete patterns
if not self._current_tag:
eos_marker = match_endofsentence(self._text)
if eos_marker:
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
# If inside tags, don't check for sentences
if self._current_tag:
continue
# No complete sentence found yet
return None
# Otherwise, use parent's lookahead logic for sentence detection
result = await super()._check_sentence_with_lookahead(char)
if result:
yield result
async def handle_interruption(self):
"""Handle interruptions by clearing the buffer.
"""Handle interruptions by clearing the buffer and tag state.
Called when an interruption occurs in the processing pipeline,
to reset the state and discard any partially aggregated text.
"""
self._text = ""
await super().handle_interruption()
self._current_tag = None
self._current_tag_index = 0
async def reset(self):
"""Clear the internally aggregated text.
"""Clear the internally aggregated text and tag state.
Resets the aggregator to its initial state, discarding any
buffered text.
"""
self._text = ""
await super().reset()
self._current_tag = None
self._current_tag_index = 0

View File

@@ -483,7 +483,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
# Add all available attributes to the span
attribute_kwargs = {
"service_name": service_class_name,
"model": getattr(self, "model_name", "unknown"),
"model": getattr(
self, getattr(self, "_full_model_name", "model_name"), "unknown"
),
"stream": True, # Most LLM services use streaming
"parameters": params,
}

View File

@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.extensions.ivr.ivr_navigator import IVRProcessor
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMMessagesUpdateFrame,
LLMTextFrame,
OutputDTMFUrgentFrame,
@@ -334,10 +335,12 @@ class TestIVRNavigation(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
LLMTextFrame(text="Hello, I'm trying to reach billing."),
LLMFullResponseEndFrame(),
]
expected_down_frames = [
LLMTextFrame, # Should pass through unchanged
LLMFullResponseEndFrame,
]
expected_up_frames = [

View File

@@ -38,14 +38,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
text = "Hello <test>pattern content</test>!"
results = [result async for result in self.aggregator.aggregate(text)]
# Verify the handler was called with correct PatternMatch object
self.test_handler.assert_called_once()
@@ -55,28 +49,37 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.text, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
self.assertEqual(result.text, "Hello !")
self.assertEqual(result.type, "sentence")
# No results yet (waiting for lookahead after "!")
self.assertEqual(len(results), 0)
# Next sentence should be processed separately. Spaces around the sentence
# should be stripped in the returned Aggregation.
result = await self.aggregator.aggregate(" This is another sentence.")
# Next sentence should provide the lookahead and trigger the previous sentence
async for result in self.aggregator.aggregate(" This is another sentence."):
results.append(result)
# First result should be "Hello !" triggered by the space lookahead
self.assertEqual(len(results), 1)
self.assertEqual(results[0].text, "Hello !")
self.assertEqual(results[0].type, "sentence")
# Now flush to get the remaining sentence
result = await self.aggregator.flush()
self.assertEqual(result.text, "This is another sentence.")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_pattern_match_and_aggregate(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Here is code <code>pattern")
self.assertEqual(result.text, "Here is code")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code_pattern")
text = "Here is code <code>pattern content</code> This is another sentence."
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</code>")
results = [result async for result in self.aggregator.aggregate(text)]
# First result should be "Here is code" when pattern starts
self.assertEqual(results[0].text, "Here is code")
self.assertEqual(results[0].type, "sentence")
# Second result should be the code pattern content
self.assertEqual(results[1].text, "pattern content")
self.assertEqual(results[1].type, "code_pattern")
# Verify the handler was called with correct PatternMatch object
self.code_handler.assert_called_once()
@@ -85,11 +88,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
# Last sentence needs flush (waiting for lookahead after ".")
result = await self.aggregator.flush()
self.assertEqual(result.text, "This is another sentence.")
self.assertEqual(result.type, "sentence")
@@ -97,11 +98,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(self.aggregator.text.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
result = await self.aggregator.aggregate("Hello <test>pattern content")
text = "Hello <test>pattern content"
results = [result async for result in self.aggregator.aggregate(text)]
# No complete pattern yet, so nothing should be returned
self.assertIsNone(result)
self.assertEqual(len(results), 0)
# The handler should not be called yet
self.test_handler.assert_not_called()
@@ -136,9 +136,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.aggregator.on_pattern_match("voice", voice_handler)
self.aggregator.on_pattern_match("emphasis", emphasis_handler)
# Test with multiple patterns in one text block
text = "Hello <voice>female</voice> I am <em>very</em> excited to meet you!"
result = await self.aggregator.aggregate(text)
results = [result async for result in self.aggregator.aggregate(text)]
# Both handlers should be called with correct data
voice_handler.assert_called_once()
@@ -151,6 +150,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
# With lookahead, we need to flush to get the final sentence
self.assertEqual(len(results), 0) # Waiting for lookahead after "!"
result = await self.aggregator.flush()
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
@@ -158,9 +161,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(self.aggregator.text.text, "")
async def test_handle_interruption(self):
# Start with incomplete pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
text = "Hello <test>pattern"
results = [result async for result in self.aggregator.aggregate(text)]
self.assertEqual(len(results), 0)
# Simulate interruption
await self.aggregator.handle_interruption()
@@ -172,20 +175,18 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_not_called()
async def test_pattern_across_sentences(self):
# Test pattern that spans multiple sentences
result = await self.aggregator.aggregate("Hello <test>This is sentence one.")
# First sentence contains start of pattern but no end, so no complete pattern yet
self.assertIsNone(result)
# Add second part with pattern end
result = await self.aggregator.aggregate(" This is sentence two.</test> Final sentence.")
text = "Hello <test>This is sentence one. This is sentence two.</test> Final sentence."
results = [result async for result in self.aggregator.aggregate(text)]
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
# With lookahead, we need to flush to get the final sentence
self.assertEqual(len(results), 0) # Waiting for lookahead after "."
result = await self.aggregator.flush()
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result.text, "Hello Final sentence.")

View File

@@ -14,22 +14,112 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
self.aggregator = SimpleTextAggregator()
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
text = "Hello "
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences yet
assert len(results) == 0
assert self.aggregator.text.text == "Hello"
await self.aggregator.reset()
assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
aggregate = await self.aggregator.aggregate("Pipecat!")
text = "Hello Pipecat!"
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences yet (waiting for lookahead after "!")
assert len(results) == 0
# Flush to get the pending sentence
aggregate = await self.aggregator.flush()
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
# Aggregators should strip leading/trailing spaces when returning text
assert self.aggregator.text.text == "How are"
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == "How are you?"
text = "Hello Pipecat! How are you?"
results = [agg async for agg in self.aggregator.aggregate(text)]
# First sentence should be complete (lookahead from "H" confirmed it)
assert len(results) == 1
assert results[0].text == "Hello Pipecat!"
# Flush to get the pending sentence
result = await self.aggregator.flush()
assert result.text == "How are you?"
async def test_lookahead_decimal_number(self):
"""Test that $29.95 is not split at $29."""
text = "Ask me for only $29.95/month."
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences yet (waiting for lookahead after final ".")
assert len(results) == 0
# Can use flush() to get the pending sentence at end of stream
result = await self.aggregator.flush()
assert result.text == "Ask me for only $29.95/month."
async def test_lookahead_abbreviation(self):
"""Test that Mr. Smith is not split at Mr."""
text = "Hello Mr. Smith."
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences yet (waiting for lookahead after final ".")
assert len(results) == 0
# Can use flush() to get the pending sentence at end of stream
result = await self.aggregator.flush()
assert result.text == "Hello Mr. Smith."
async def test_lookahead_actual_sentence_end(self):
"""Test that a real sentence end is detected after lookahead."""
text = "Hello world. Next sentence"
results = [agg async for agg in self.aggregator.aggregate(text)]
# First sentence should be complete (lookahead from "N" confirmed it)
assert len(results) == 1
assert results[0].text == "Hello world."
async def test_flush_pending_sentence(self):
"""Test that flush() returns pending sentence waiting for lookahead."""
text = "Hello world."
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences yet (waiting for lookahead)
assert len(results) == 0
# Call flush to get it
result = await self.aggregator.flush()
assert result is not None
assert result.text == "Hello world."
# Flush again should return None
assert await self.aggregator.flush() == None
async def test_flush_with_no_pending(self):
"""Test that flush() returns any remaining text in buffer."""
text = "Hello"
results = [agg async for agg in self.aggregator.aggregate(text)]
# No complete sentences
assert len(results) == 0
result = await self.aggregator.flush()
# flush() now returns any remaining text, not just pending lookahead
assert result is not None
assert result.text == "Hello"
# Buffer should be empty after flush
assert self.aggregator.text.text == ""
async def test_flush_after_lookahead_confirmed(self):
"""Test flush after lookahead has already confirmed sentence."""
text = "Hello. W"
results = [agg async for agg in self.aggregator.aggregate(text)]
# First sentence should be complete (lookahead from "W" confirmed it)
assert len(results) == 1
assert results[0].text == "Hello."
# flush() returns any remaining text (the "W" in this case)
result = await self.aggregator.flush()
assert result.text == "W"

View File

@@ -17,7 +17,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.reset()
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
text = "Hello Pipecat!"
results = [agg async for agg in self.aggregator.aggregate(text)]
# Should still be waiting for lookahead after "!"
self.assertEqual(len(results), 0)
# Flush to get the pending sentence
result = await self.aggregator.flush()
self.assertEqual(result.text, "Hello Pipecat!")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
@@ -26,7 +33,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.reset()
# Tags involved, avoid aggregation during tags.
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
text = "My email is <spell>foo@pipecat.ai</spell>."
results = [agg async for agg in self.aggregator.aggregate(text)]
# Should still be waiting for lookahead after "."
self.assertEqual(len(results), 0)
# Flush to get the pending sentence
result = await self.aggregator.flush()
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
@@ -34,25 +48,17 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
async def test_streaming_tags(self):
await self.aggregator.reset()
# Tags involved, stream small chunk of texts.
result = await self.aggregator.aggregate("My email is <sp")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <sp")
# Tags involved
text = "My email is <spell>foo.bar@pipecat.ai</spell>."
results = [agg async for agg in self.aggregator.aggregate(text)]
result = await self.aggregator.aggregate("ell>foo.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
result = await self.aggregator.aggregate("bar@pipecat.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
result = await self.aggregator.aggregate("ai</spe")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
# Should still be waiting for lookahead after "."
self.assertEqual(len(results), 0)
self.assertEqual(self.aggregator.text.text, text)
self.assertEqual(self.aggregator.text.type, "sentence")
result = await self.aggregator.aggregate("ll>.")
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
# Flush to get the pending sentence
result = await self.aggregator.flush()
self.assertEqual(result.text, text)
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text.type, "sentence")

34
uv.lock generated
View File

@@ -36,12 +36,12 @@ wheels = [
[[package]]
name = "aic-sdk"
version = "1.1.0"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8ebc574fb4a4f23c07a3740b51992d7522976173d30b98/aic_sdk-1.1.0.tar.gz", hash = "sha256:04e08df695581c8cb4db8acca20e73815e9f449e7bd08e0162fd55518c727963", size = 34954, upload-time = "2025-11-11T20:45:24.25Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/ba/3ebe31b91e03d42437ec864e9d2af3a52b7ccc73a1a0c1026275956270b0/aic_sdk-1.2.0.tar.gz", hash = "sha256:eeda9a181c679f175dbe6f0efc0c67ec98ff3d84cfe01541fef7fa12ecd505ca", size = 35606, upload-time = "2025-11-20T14:42:14.333Z" }
[[package]]
name = "aioboto3"
@@ -4496,6 +4496,9 @@ google = [
{ name = "google-genai" },
{ name = "websockets" },
]
gradium = [
{ name = "websockets" },
]
groq = [
{ name = "groq" },
]
@@ -4564,6 +4567,9 @@ neuphonic = [
noisereduce = [
{ name = "noisereduce" },
]
nvidia = [
{ name = "nvidia-riva-client" },
]
openai = [
{ name = "websockets" },
]
@@ -4652,6 +4658,7 @@ dev = [
{ name = "ruff" },
{ name = "setuptools" },
{ name = "setuptools-scm" },
{ name = "towncrier" },
]
docs = [
{ name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
@@ -4666,7 +4673,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.2.0" },
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.5.0" },
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
{ name = "aiohttp", specifier = ">=3.11.12,<4" },
@@ -4706,7 +4713,7 @@ requires-dist = [
{ name = "noisereduce", marker = "extra == 'noisereduce'", specifier = "~=3.0.3" },
{ name = "numba", specifier = "==0.61.2" },
{ name = "numpy", specifier = ">=1.26.4,<3" },
{ name = "nvidia-riva-client", marker = "extra == 'riva'", specifier = "~=2.21.1" },
{ name = "nvidia-riva-client", marker = "extra == 'nvidia'", specifier = "~=2.21.1" },
{ name = "onnxruntime", marker = "extra == 'local-smart-turn-v3'", specifier = ">=1.20.1,<2" },
{ name = "onnxruntime", marker = "extra == 'silero'", specifier = ">=1.20.1,<2" },
{ name = "openai", specifier = ">=1.74.0,<3" },
@@ -4717,6 +4724,7 @@ requires-dist = [
{ name = "opentelemetry-sdk", marker = "extra == 'tracing'", specifier = ">=1.33.0" },
{ name = "ormsgpack", marker = "extra == 'fish'", specifier = "~=1.7.0" },
{ name = "pillow", specifier = ">=11.1.0,<12" },
{ name = "pipecat-ai", extras = ["nvidia"], marker = "extra == 'riva'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'assemblyai'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'asyncai'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'aws'" },
@@ -4726,6 +4734,7 @@ requires-dist = [
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'fish'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'gladia'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'google'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'gradium'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'heygen'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'lmnt'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'neuphonic'" },
@@ -4767,7 +4776,7 @@ requires-dist = [
{ name = "wait-for2", marker = "python_full_version < '3.12'", specifier = ">=0.4.1" },
{ name = "websockets", marker = "extra == 'websockets-base'", specifier = ">=13.1,<16.0" },
]
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "nim", "noisereduce", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "gradium", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "noisereduce", "nvidia", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
[package.metadata.requires-dev]
dev = [
@@ -4784,6 +4793,7 @@ dev = [
{ name = "ruff", specifier = ">=0.12.11,<1" },
{ name = "setuptools", specifier = "~=78.1.1" },
{ name = "setuptools-scm", specifier = "~=8.3.1" },
{ name = "towncrier", specifier = "~=25.8.0" },
]
docs = [
{ name = "sphinx", specifier = ">=8.1.3" },
@@ -7243,6 +7253,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621, upload-time = "2025-04-23T14:41:46.06Z" },
]
[[package]]
name = "towncrier"
version = "25.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "jinja2" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c2/eb/5bf25a34123698d3bbab39c5bc5375f8f8bcbcc5a136964ade66935b8b9d/towncrier-25.8.0.tar.gz", hash = "sha256:eef16d29f831ad57abb3ae32a0565739866219f1ebfbdd297d32894eb9940eb1", size = 76322, upload-time = "2025-08-30T11:41:55.393Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/06/8ba22ec32c74ac1be3baa26116e3c28bc0e76a5387476921d20b6fdade11/towncrier-25.8.0-py3-none-any.whl", hash = "sha256:b953d133d98f9aeae9084b56a3563fd2519dfc6ec33f61c9cd2c61ff243fb513", size = 65101, upload-time = "2025-08-30T11:41:53.644Z" },
]
[[package]]
name = "tqdm"
version = "4.67.1"