Compare commits
94 Commits
aleix/clau
...
v0.0.97
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cefe1357c | ||
|
|
4df0a9bf73 | ||
|
|
9ef139d020 | ||
|
|
9103d4ae05 | ||
|
|
bd63b6cefa | ||
|
|
4d03270bc3 | ||
|
|
6aee72c5b4 | ||
|
|
8d62cfb1b6 | ||
|
|
41214236ab | ||
|
|
b25963a63b | ||
|
|
8c6ef21d84 | ||
|
|
0ffaa09c95 | ||
|
|
f6e31b7e89 | ||
|
|
48422dd442 | ||
|
|
fed6a8b669 | ||
|
|
82e0253a62 | ||
|
|
a7f26dca60 | ||
|
|
459ef27f3f | ||
|
|
464cfa5ccb | ||
|
|
9289881a80 | ||
|
|
34033cd454 | ||
|
|
47c21c9579 | ||
|
|
3b0bcf0b66 | ||
|
|
c4a8308027 | ||
|
|
e9f76dcaf2 | ||
|
|
21b2229b2b | ||
|
|
11aa9c9e68 | ||
|
|
9f4680e9bd | ||
|
|
04443a3820 | ||
|
|
1571cc58ac | ||
|
|
dea80cf946 | ||
|
|
91dec044c4 | ||
|
|
8cf4267d87 | ||
|
|
0ee7cab6c6 | ||
|
|
74c2039bfb | ||
|
|
66088837cd | ||
|
|
07ebf8534a | ||
|
|
fce4cfba15 | ||
|
|
af52833ca0 | ||
|
|
9fdf756375 | ||
|
|
283bbb385c | ||
|
|
8c6b2edb25 | ||
|
|
6ab30f9b87 | ||
|
|
3d93285bdf | ||
|
|
7261cd28f2 | ||
|
|
33eeb8ce44 | ||
|
|
ebda94ca98 | ||
|
|
40b17cff8f | ||
|
|
7ba0ebba11 | ||
|
|
b39087027c | ||
|
|
e65974c870 | ||
|
|
b1e5d68d97 | ||
|
|
39bca074d7 | ||
|
|
b5e79f9dc5 | ||
|
|
613b96819f | ||
|
|
57c24670ea | ||
|
|
d79dd94019 | ||
|
|
fa8e7458e1 | ||
|
|
4d66191963 | ||
|
|
7e9d67002e | ||
|
|
ffbb6e5937 | ||
|
|
535b85cf90 | ||
|
|
8dc9872ed5 | ||
|
|
f37a53cc25 | ||
|
|
9cce28c64c | ||
|
|
3ca94363ec | ||
|
|
9dd882ecf8 | ||
|
|
0bbb14eb9b | ||
|
|
050f287ec4 | ||
|
|
e6f5561785 | ||
|
|
2df91f4b37 | ||
|
|
7db49b9067 | ||
|
|
7c497bdc89 | ||
|
|
1aa4247d2b | ||
|
|
1ffa9ff51f | ||
|
|
435b53f1a0 | ||
|
|
406bdfad0d | ||
|
|
acba544e6f | ||
|
|
5d93c64ee5 | ||
|
|
de10bc8803 | ||
|
|
36f5c1722d | ||
|
|
a8280522e5 | ||
|
|
05d65dfdd3 | ||
|
|
a3962e3b47 | ||
|
|
cd231cf829 | ||
|
|
9fafc1692d | ||
|
|
7648d0436c | ||
|
|
bff8747e38 | ||
|
|
d227c0c097 | ||
|
|
9ccde60521 | ||
|
|
b84a40666c | ||
|
|
e72b135a4c | ||
|
|
7961f8a664 | ||
|
|
4ca143e8af |
174
.github/workflows/generate-changelog.yml
vendored
Normal file
174
.github/workflows/generate-changelog.yml
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
name: Generate Changelog for Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "Release version (e.g., 0.0.97)"
|
||||
required: true
|
||||
type: string
|
||||
date:
|
||||
description: "Release date (YYYY-MM-DD format, defaults to today)"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
generate-changelog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Set release date
|
||||
id: set_date
|
||||
run: |
|
||||
if [ -z "${{ inputs.date }}" ]; then
|
||||
RELEASE_DATE=$(date +%Y-%m-%d)
|
||||
echo "Using today's date: $RELEASE_DATE"
|
||||
else
|
||||
RELEASE_DATE="${{ inputs.date }}"
|
||||
echo "Using provided date: $RELEASE_DATE"
|
||||
fi
|
||||
echo "release_date=$RELEASE_DATE" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Validate inputs
|
||||
run: |
|
||||
# Validate version format (basic check)
|
||||
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+.*$ ]]; then
|
||||
echo "Error: Version must be in format X.Y.Z (e.g., 0.0.97)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate date format if provided
|
||||
if [ -n "${{ inputs.date }}" ]; then
|
||||
if ! date -d "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
# Try macOS date format
|
||||
if ! date -j -f "%Y-%m-%d" "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
echo "Error: Date must be in YYYY-MM-DD format (e.g., 2025-12-04)"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Check for changelog fragments
|
||||
id: check_fragments
|
||||
run: |
|
||||
FRAGMENT_COUNT=$(find changelog -name "*.md" ! -name "_template.md.j2" | wc -l | tr -d ' ')
|
||||
echo "fragment_count=$FRAGMENT_COUNT" >> $GITHUB_OUTPUT
|
||||
|
||||
if [ "$FRAGMENT_COUNT" -eq "0" ]; then
|
||||
echo "❌ Error: No changelog fragments found in changelog/"
|
||||
echo ""
|
||||
echo "Cannot create a release without changelog entries."
|
||||
echo "Add changelog fragments to the changelog/ directory (e.g., 1234.added.md) and try again."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate fragment types
|
||||
VALID_TYPES="added changed deprecated removed fixed security"
|
||||
INVALID_FRAGMENTS=""
|
||||
|
||||
for file in changelog/*.md; do
|
||||
# Skip template
|
||||
if [[ "$file" == "changelog/_template.md.j2" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract type from filename (e.g., 1234.added.md -> added)
|
||||
filename=$(basename "$file")
|
||||
# Handle both 1234.added.md and 1234.added.2.md patterns
|
||||
type=$(echo "$filename" | sed -E 's/^[0-9]+\.([a-z]+)(\.[0-9]+)?\.md$/\1/')
|
||||
|
||||
# Check if type is valid
|
||||
if ! echo "$VALID_TYPES" | grep -wq "$type"; then
|
||||
INVALID_FRAGMENTS="$INVALID_FRAGMENTS\n - $filename (type: '$type')"
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -n "$INVALID_FRAGMENTS" ]; then
|
||||
echo "❌ Error: Invalid changelog fragment types found:"
|
||||
echo -e "$INVALID_FRAGMENTS"
|
||||
echo ""
|
||||
echo "Valid types are: $VALID_TYPES"
|
||||
echo "Example: 1234.added.md, 5678.fixed.md"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Found $FRAGMENT_COUNT changelog fragment(s)"
|
||||
echo "has_fragments=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Preview changelog
|
||||
run: |
|
||||
echo "## Preview of changelog for version ${{ inputs.version }}"
|
||||
echo ""
|
||||
uv run towncrier build --draft --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}"
|
||||
|
||||
- name: Build changelog
|
||||
run: |
|
||||
uv run towncrier build --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}" --yes
|
||||
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update changelog for version ${{ inputs.version }}"
|
||||
title: "Release ${{ inputs.version }} - Changelog Update"
|
||||
body: |
|
||||
## Changelog Update for Release ${{ inputs.version }}
|
||||
|
||||
This PR updates the CHANGELOG.md with all changes for version **${{ inputs.version }}**.
|
||||
|
||||
### Summary
|
||||
- **Version:** ${{ inputs.version }}
|
||||
- **Date:** ${{ steps.set_date.outputs.release_date }}
|
||||
- **Fragments processed:** ${{ steps.check_fragments.outputs.fragment_count }}
|
||||
|
||||
### What this PR does
|
||||
- ✅ Adds new release section to CHANGELOG.md
|
||||
- ✅ Removes processed changelog fragments
|
||||
- ✅ Ready to merge for release
|
||||
|
||||
### Next Steps
|
||||
1. Review the changelog entries below
|
||||
2. Make any necessary edits to CHANGELOG.md if needed
|
||||
3. Merge this PR
|
||||
4. Continue with your release process
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>📋 Preview of changes</summary>
|
||||
|
||||
The changelog has been updated with entries from the following fragments:
|
||||
|
||||
```bash
|
||||
${{ steps.check_fragments.outputs.fragment_count }} fragments processed
|
||||
```
|
||||
|
||||
</details>
|
||||
branch: changelog-${{ inputs.version }}
|
||||
delete-branch: true
|
||||
labels: |
|
||||
changelog
|
||||
release
|
||||
96
CHANGELOG.md
96
CHANGELOG.md
@@ -5,25 +5,115 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
<!-- towncrier release notes start -->
|
||||
|
||||
## [0.0.97] - 2025-12-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
|
||||
speech-to-text and text-to-speech functionality using Gradium's API.
|
||||
|
||||
- Additions for `AsyncAITTSService` and `AsyncAIHttpTTSService`:
|
||||
|
||||
- Added new `languages`: `pt`, `nl`, `ar`, `ru`, `ro`, `ja`, `he`, `hy`,
|
||||
`tr`, `hi`, `zh`.
|
||||
- Updated the default model to `asyncflow_multilingual_v1.0` for improved
|
||||
accuracy and broader language coverage.
|
||||
|
||||
- Added optional tool and tool output filters for MCP services.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Deepgram logging to include Deepgram request IDs for improved
|
||||
debugging.
|
||||
|
||||
- Text Aggregation Improvements:
|
||||
|
||||
- **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
|
||||
`AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
|
||||
enables the aggregator to return multiple results based on the provided
|
||||
text.
|
||||
- Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
|
||||
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
|
||||
the base class's sentence detection logic.
|
||||
|
||||
- Improved interruption handling to prevent bots from repeating themselves. LLM
|
||||
services that return multiple sentences in a single response (e.g.,
|
||||
`GoogleLLMService`) are now split into individual sentences before being sent
|
||||
to TTS. This ensures interruptions occur at sentence boundaries, preventing
|
||||
the bot from repeating content after being interrupted during long responses.
|
||||
|
||||
- Updated `AICFilter` to use Quail STT as the default model
|
||||
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
|
||||
interaction (e.g., voice agents, speech-to-text) and operates at a native
|
||||
sample rate of 16 kHz with fixed enhancement parameters.
|
||||
|
||||
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is
|
||||
called with an exception, the file name and line number where the exception
|
||||
occured are now logged.
|
||||
|
||||
- Updated Smart Turn model weights to v3.1.
|
||||
|
||||
- Smart Turn analyzer now uses the full context of the turn rather than just
|
||||
the audio since VAD last triggered.
|
||||
|
||||
- Updated `CartesiaSTTService` to return the full transcription `result` in the
|
||||
`TranscriptionFrame` and `InterimTranscriptionFrame`. This provides access to
|
||||
word timestamp data.
|
||||
|
||||
- `HumeTTSService` changes:
|
||||
|
||||
- Added tracking headers (`X-Hume-Client-Name` and `X-Hume-Client-Version`)
|
||||
to all requests made by `HumeTTSService` to the Hume API for better usage
|
||||
tracking and analytics.
|
||||
- Added `stop()` and `cancel()` cleanup methods to `HumeTTSService` to
|
||||
properly close the HTTP client and prevent resource leaks.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- NVIDIA Services name changes (all functionality is unchanged):
|
||||
|
||||
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
|
||||
- `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
|
||||
- `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
|
||||
- Use `uv pip install pipecat-ai[nvidia]` instead of
|
||||
`uv pip install pipecat-ai[riva]`
|
||||
|
||||
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
|
||||
has any effect. Noise gating is now handled automatically by the AIC VAD
|
||||
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
|
||||
|
||||
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
|
||||
always set to `us-east-1` when providing an AWS_REGION env var.
|
||||
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
|
||||
multiple times for `KEEP` or `AGGREGATE` patterns.
|
||||
|
||||
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
|
||||
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
|
||||
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was always
|
||||
set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
|
||||
spoken. Now, audio is flushed when the TTS services receives the
|
||||
`LLMFullResponseEndFrame` or `EndFrame`.
|
||||
|
||||
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
|
||||
incorrectly pushed after a functional call. This caused an issue with the
|
||||
voice-ui-kit's conversational panel rending of the LLM output after a
|
||||
function call.
|
||||
|
||||
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
|
||||
services.
|
||||
|
||||
- Fixed an issue that caused `WebsocketService` instances to attempt
|
||||
reconnection during shutdown.
|
||||
|
||||
- Fixed an issue in `ElevenLabsTTSService` where character usage metrics were
|
||||
only reported on the first TTS generation per turn.
|
||||
|
||||
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
|
||||
|
||||
**Examples:**
|
||||
|
||||
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
|
||||
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
|
||||
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
|
||||
|
||||
#### Key requirements:
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -17,24 +17,121 @@ We welcome contributions of all kinds! Your help is appreciated. Follow these st
|
||||
git checkout -b your-branch-name
|
||||
```
|
||||
4. **Make your changes**: Edit or add files as necessary.
|
||||
5. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
6. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
5. **Add a changelog entry**: Create a changelog fragment file (see [Changelog Entries](#changelog-entries) below).
|
||||
6. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
7. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
|
||||
```bash
|
||||
git commit -m "Description of your changes"
|
||||
```
|
||||
|
||||
7. **Push your changes**: Push your branch to your forked repository.
|
||||
8. **Push your changes**: Push your branch to your forked repository.
|
||||
|
||||
```bash
|
||||
git push origin your-branch-name
|
||||
```
|
||||
|
||||
8. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
9. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
> Important: Describe the changes you've made clearly!
|
||||
|
||||
Our maintainers will review your PR, and once everything is good, your contributions will be merged!
|
||||
|
||||
## Changelog Entries
|
||||
|
||||
Every pull request that makes a user-facing change should include a changelog entry. We use a changelog fragment system to avoid merge conflicts.
|
||||
|
||||
### Creating a Changelog Fragment
|
||||
|
||||
1. Create a new file in the `changelog/` directory with this naming pattern:
|
||||
|
||||
```
|
||||
<PR_number>.<type>.md
|
||||
```
|
||||
|
||||
2. Choose the appropriate type:
|
||||
|
||||
- `added.md` - New features
|
||||
- `changed.md` - Changes in existing functionality
|
||||
- `deprecated.md` - Soon-to-be removed features
|
||||
- `removed.md` - Removed features
|
||||
- `fixed.md` - Bug fixes
|
||||
- `security.md` - Security fixes
|
||||
|
||||
3. Write your changelog entry as a Markdown bullet point. Include the `-` at the start:
|
||||
|
||||
**Example files:**
|
||||
|
||||
`changelog/1234.added.md`:
|
||||
|
||||
```markdown
|
||||
- Added support for Anthropic Claude 3.5 Sonnet with improved streaming performance.
|
||||
```
|
||||
|
||||
`changelog/5678.fixed.md`:
|
||||
|
||||
```markdown
|
||||
- Fixed an issue where audio frames were dropped during high-load scenarios.
|
||||
```
|
||||
|
||||
**For entries with nested bullets:**
|
||||
|
||||
`changelog/1234.changed.md`:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
### Multiple Changes in One PR
|
||||
|
||||
**Different types of changes:** Create separate fragment files for each type:
|
||||
|
||||
```
|
||||
changelog/1234.added.md
|
||||
changelog/1234.fixed.md
|
||||
```
|
||||
|
||||
**Multiple changes of the same type:** Create numbered fragment files:
|
||||
|
||||
```
|
||||
changelog/1234.changed.md
|
||||
changelog/1234.changed.2.md
|
||||
```
|
||||
|
||||
**Related changes:** Use nested bullets in a single fragment:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
**Rule of thumb:** One logical change per fragment file. If changes are unrelated, use separate files.
|
||||
|
||||
### Preview Your Changes
|
||||
|
||||
To see what your changelog entry will look like:
|
||||
|
||||
```bash
|
||||
towncrier build --draft --version Unreleased
|
||||
```
|
||||
|
||||
This won't modify any files, just show you a preview.
|
||||
|
||||
### When to Skip Changelog Entries
|
||||
|
||||
You can skip adding a changelog entry for:
|
||||
|
||||
- Documentation-only changes
|
||||
- Internal refactoring with no user-facing impact
|
||||
- Test-only changes
|
||||
- CI/build configuration changes
|
||||
|
||||
If you're unsure whether your change needs a changelog entry, ask in your PR!
|
||||
|
||||
## Dependency Management
|
||||
|
||||
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. The `uv.lock` file is committed to ensure reproducible builds.
|
||||
|
||||
@@ -74,9 +74,9 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
|
||||
16
changelog/_template.md.j2
Normal file
16
changelog/_template.md.j2
Normal file
@@ -0,0 +1,16 @@
|
||||
{% for section, _ in sections.items() %}
|
||||
{% if sections[section] %}
|
||||
{% for category, val in definitions.items() if category in sections[section]%}
|
||||
### {{ definitions[category]['name'] }}
|
||||
|
||||
{% for text, values in sections[section][category].items() %}
|
||||
{{ text }}
|
||||
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
No significant changes.
|
||||
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
@@ -119,7 +119,6 @@ def import_core_modules():
|
||||
"pipecat.observers",
|
||||
"pipecat.runner",
|
||||
"pipecat.serializers",
|
||||
"pipecat.sync",
|
||||
"pipecat.transcriptions",
|
||||
"pipecat.utils",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,6 @@ Quick Links
|
||||
Runner <api/pipecat.runner>
|
||||
Serializers <api/pipecat.serializers>
|
||||
Services <api/pipecat.services>
|
||||
Sync <api/pipecat.sync>
|
||||
Transcriptions <api/pipecat.transcriptions>
|
||||
Transports <api/pipecat.transports>
|
||||
Utils <api/pipecat.utils>
|
||||
Utils <api/pipecat.utils>
|
||||
|
||||
@@ -73,6 +73,9 @@ GOOGLE_CLOUD_PROJECT_ID=...
|
||||
GOOGLE_CLOUD_LOCATION=...
|
||||
GOOGLE_TEST_CREDENTIALS=...
|
||||
|
||||
# Gradium
|
||||
GRAPDIUM_API_KEY=...
|
||||
|
||||
# Grok
|
||||
GROK_API_KEY=...
|
||||
|
||||
@@ -191,4 +194,4 @@ TWILIO_AUTH_TOKEN=...
|
||||
WHATSAPP_TOKEN=...
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
|
||||
WHATSAPP_PHONE_NUMBER_ID=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
|
||||
@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.riva.tts import FastPitchTTSService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -36,7 +36,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
task = PipelineTask(
|
||||
Pipeline([tts, transport.output()]),
|
||||
127
examples/foundational/07af-interruptible-gradium.py
Normal file
127
examples/foundational/07af-interruptible-gradium.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.riva.stt import RivaSTTService
|
||||
from pipecat.services.riva.tts import RivaTTSService
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nvidia.stt import NvidiaSTTService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,11 +59,13 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
)
|
||||
|
||||
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -76,7 +76,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = FireworksLLMService(
|
||||
api_key=os.getenv("FIREWORKS_API_KEY"),
|
||||
model="accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
model="accounts/fireworks/models/gpt-oss-20b",
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NimLLMService(
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Recommended when turning thinking off
|
||||
params=NimLLMService.InputParams(temperature=0.0),
|
||||
params=NvidiaLLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
@@ -14,20 +14,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
TranscriptionMessage,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
|
||||
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
|
||||
@@ -19,7 +19,6 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
|
||||
@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -64,11 +64,14 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@@ -88,6 +91,23 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# full list of tools available from rijksmuseum MCP:
|
||||
# - get_artwork_details
|
||||
# - get_artwork_image
|
||||
# - get_user_sets
|
||||
# - get_user_set_details
|
||||
# - open_image_in_browser
|
||||
# - get_artist_timeline
|
||||
|
||||
mcp_tools_filter = ["get_artwork_details", "get_artwork_image", "open_image_in_browser"]
|
||||
|
||||
|
||||
def open_image_output_filter(output: str):
|
||||
pattern = r"Successfully opened image in browser: "
|
||||
text_to_print = re.sub(pattern, "", output)
|
||||
print(f"🖼️ link to high resolution artwork: {text_to_print}")
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -136,7 +156,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
),
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
|
||||
@@ -67,13 +67,14 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
try:
|
||||
|
||||
@@ -63,6 +63,7 @@ fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
|
||||
gradium = [ "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
@@ -83,8 +84,8 @@ mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
nim = []
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
nvidia = [ "nvidia-riva-client~=2.21.1" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
@@ -93,7 +94,7 @@ playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
riva = [ "pipecat-ai[nvidia]" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
@@ -129,6 +130,7 @@ dev = [
|
||||
"setuptools~=78.1.1",
|
||||
"setuptools_scm~=8.3.1",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"towncrier~=25.8.0",
|
||||
]
|
||||
|
||||
docs = [
|
||||
@@ -159,7 +161,7 @@ where = ["src"]
|
||||
"src/pipecat/audio/dtmf/dtmf-star.wav",
|
||||
]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.1-cpu.onnx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
@@ -206,3 +208,44 @@ convention = "google"
|
||||
command_line = "--module pytest"
|
||||
source = ["src"]
|
||||
omit = ["*/tests/*"]
|
||||
|
||||
[tool.towncrier]
|
||||
package = "pipecat"
|
||||
package_dir = "src"
|
||||
filename = "CHANGELOG.md"
|
||||
directory = "changelog"
|
||||
start_string = "<!-- towncrier release notes start -->\n"
|
||||
template = "changelog/_template.md.j2"
|
||||
title_format = "## [{version}] - {project_date}"
|
||||
underlines = ["", "", ""]
|
||||
wrap = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "added"
|
||||
name = "Added"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "changed"
|
||||
name = "Changed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "deprecated"
|
||||
name = "Deprecated"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "removed"
|
||||
name = "Removed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "fixed"
|
||||
name = "Fixed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "security"
|
||||
name = "Security"
|
||||
showcontent = true
|
||||
|
||||
@@ -103,7 +103,7 @@ TESTS_07 = [
|
||||
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
|
||||
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
|
||||
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
|
||||
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
|
||||
@@ -136,7 +136,7 @@ TESTS_14 = [
|
||||
("14g-function-calling-grok.py", EVAL_WEATHER),
|
||||
("14h-function-calling-azure.py", EVAL_WEATHER),
|
||||
("14i-function-calling-fireworks.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nim.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nvidia.py", EVAL_WEATHER),
|
||||
("14k-function-calling-cerebras.py", EVAL_WEATHER),
|
||||
("14m-function-calling-openrouter.py", EVAL_WEATHER),
|
||||
("14n-function-calling-perplexity.py", EVAL_WEATHER),
|
||||
|
||||
@@ -28,7 +28,6 @@ from pipecat.metrics.metrics import MetricsData, SmartTurnMetricsData
|
||||
STOP_SECS = 3
|
||||
PRE_SPEECH_MS = 0
|
||||
MAX_DURATION_SECONDS = 8 # Max allowed segment duration
|
||||
USE_ONLY_LAST_VAD_SEGMENT = True
|
||||
|
||||
|
||||
class SmartTurnParams(BaseTurnParams):
|
||||
@@ -43,8 +42,6 @@ class SmartTurnParams(BaseTurnParams):
|
||||
stop_secs: float = STOP_SECS
|
||||
pre_speech_ms: float = PRE_SPEECH_MS
|
||||
max_duration_secs: float = MAX_DURATION_SECONDS
|
||||
# not exposing this for now yet until the model can handle it.
|
||||
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
|
||||
|
||||
|
||||
class SmartTurnTimeoutException(Exception):
|
||||
@@ -160,7 +157,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
state, result = await loop.run_in_executor(
|
||||
self._executor, self._process_speech_segment, self._audio_buffer
|
||||
)
|
||||
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
self._clear(state)
|
||||
logger.debug(f"End of Turn result: {state}")
|
||||
return state, result
|
||||
|
||||
Binary file not shown.
@@ -42,17 +42,15 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
Args:
|
||||
smart_turn_model_path: Path to the ONNX model file. If this is not
|
||||
set, the bundled smart-turn-v3.0 model will be used.
|
||||
set, the bundled smart-turn-v3.1-cpu model will be used.
|
||||
cpu_count: The number of CPUs to use for inference. Defaults to 1.
|
||||
**kwargs: Additional arguments passed to BaseSmartTurn.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
logger.debug("Loading Local Smart Turn v3 model...")
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.0.onnx"
|
||||
model_name = "smart-turn-v3.1-cpu.onnx"
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
@@ -70,6 +68,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
)
|
||||
|
||||
logger.debug(f"Loading Local Smart Turn v3.x model from {smart_turn_model_path}...")
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
so.inter_op_num_threads = 1
|
||||
@@ -79,7 +79,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3")
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
@@ -18,8 +18,10 @@ from loguru import logger
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMTextFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -149,11 +151,18 @@ class IVRProcessor(FrameProcessor):
|
||||
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
# Process text through the pattern aggregator
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
async for result in self._aggregator.aggregate(frame.text):
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
# Flush any remaining text from the aggregator
|
||||
remaining = await self._aggregator.flush()
|
||||
if remaining:
|
||||
await self.push_frame(LLMTextFrame(remaining.text), direction)
|
||||
# Push the end frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class NotifierGate(FrameProcessor):
|
||||
|
||||
@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
|
||||
"""
|
||||
|
||||
text: str
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
# Whether any necessary inter-frame (leading/trailing) spaces are already
|
||||
# included in the text.
|
||||
# NOTE: Ideally this would be available at init time with a default value,
|
||||
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
@@ -1632,22 +1632,22 @@ class LLMFullResponseStartFrame(ControlFrame):
|
||||
more TextFrames and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseEndFrame(ControlFrame):
|
||||
"""Frame indicating the end of an LLM response."""
|
||||
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class GatedLLMContextAggregator(FrameProcessor):
|
||||
|
||||
@@ -83,8 +83,7 @@ class LLMTextProcessor(FrameProcessor):
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
@@ -92,15 +91,13 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
|
||||
# Flush any remaining text
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class WakeNotifierFilter(FrameProcessor):
|
||||
|
||||
@@ -12,6 +12,7 @@ management, and frame flow control mechanisms.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence, Tuple, Type
|
||||
@@ -677,7 +678,17 @@ class FrameProcessor(BaseObject):
|
||||
if not error.processor:
|
||||
error.processor = self
|
||||
await self._call_event_handler("on_error", error)
|
||||
logger.error(f"{error.processor} error: {error.error}")
|
||||
|
||||
if error.exception:
|
||||
tb = traceback.extract_tb(error.exception.__traceback__)
|
||||
last = tb[-1]
|
||||
error_message = (
|
||||
f"{error.processor} exception ({last.filename}:{last.lineno}): {error.error}"
|
||||
)
|
||||
else:
|
||||
error_message = f"{error.processor} error: {error.error}"
|
||||
|
||||
logger.error(error_message)
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
|
||||
@@ -935,8 +935,8 @@ class RTVIObserverParams:
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
|
||||
@@ -56,6 +56,17 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
Language.ES: "es",
|
||||
Language.DE: "de",
|
||||
Language.IT: "it",
|
||||
Language.PT: "pt",
|
||||
Language.NL: "nl",
|
||||
Language.AR: "ar",
|
||||
Language.RU: "ru",
|
||||
Language.RO: "ro",
|
||||
Language.JA: "ja",
|
||||
Language.HE: "he",
|
||||
Language.HY: "hy",
|
||||
Language.TR: "tr",
|
||||
Language.HI: "hi",
|
||||
Language.ZH: "zh",
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
@@ -74,7 +85,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
language: Optional[Language] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -83,7 +94,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
voice_id: str,
|
||||
version: str = "v1",
|
||||
url: str = "wss://api.async.ai/text_to_speech/websocket/ws",
|
||||
model: str = "asyncflow_v2.0",
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
@@ -99,7 +110,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
https://docs.async.ai/list-voices-16699698e0
|
||||
version: Async API version.
|
||||
url: WebSocket URL for Async TTS API.
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
sample_rate: Audio sample rate.
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
@@ -128,7 +139,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en",
|
||||
else None,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
@@ -357,7 +368,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
language: Optional[Language] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -365,7 +376,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "asyncflow_v2.0",
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
url: str = "https://api.async.ai",
|
||||
version: str = "v1",
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -380,7 +391,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: Async API key.
|
||||
voice_id: ID of the voice to use for synthesis.
|
||||
aiohttp_session: An aiohttp session for making HTTP requests.
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
url: Base URL for Async API.
|
||||
version: API version string for Async API.
|
||||
sample_rate: Audio sample rate.
|
||||
@@ -404,7 +415,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en",
|
||||
else None,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
|
||||
@@ -20,7 +20,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -349,6 +348,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
@@ -361,5 +361,6 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -244,6 +244,11 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
|
||||
@@ -234,6 +234,13 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
if not await self._connection.start(options=self._settings, addons=self._addons):
|
||||
await self.push_error(error_msg=f"Unable to connect to Deepgram")
|
||||
else:
|
||||
headers = {
|
||||
k: v
|
||||
for k, v in self._connection._socket.response.headers.items()
|
||||
if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
async def _disconnect(self):
|
||||
if await self._connection.is_connected():
|
||||
|
||||
@@ -71,7 +71,12 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
@@ -165,6 +170,11 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
@@ -231,7 +241,6 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
logger.trace(f"Received Flushed: {msg}")
|
||||
# Flushed indicates the end of audio generation for the current buffer
|
||||
# This happens after flush_audio() is called
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {msg}")
|
||||
# Buffer has been cleared after interruption
|
||||
@@ -286,7 +295,7 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
speak_msg = {"type": "Speak", "text": text}
|
||||
await self._get_websocket().send(json.dumps(speak_msg))
|
||||
|
||||
# The actual audio frames will be handled in _receive_messages
|
||||
# The audio frames will be handled in _receive_messages
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -160,7 +160,7 @@ def build_elevenlabs_voice_settings(
|
||||
class PronunciationDictionaryLocator(BaseModel):
|
||||
"""Locator for a pronunciation dictionary.
|
||||
|
||||
Attributes:
|
||||
Parameters:
|
||||
pronunciation_dictionary_id: The ID of the pronunciation dictionary.
|
||||
version_id: The version ID of the pronunciation dictionary.
|
||||
"""
|
||||
@@ -731,10 +731,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
else:
|
||||
await self._send_text(text)
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
5
src/pipecat/services/gradium/__init__.py
Normal file
5
src/pipecat/services/gradium/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
239
src/pipecat/services/gradium/stt.py
Normal file
239
src/pipecat/services/gradium/stt.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gradium's speech-to-text service implementation.
|
||||
|
||||
This module provides integration with Gradium's real-time speech-to-text
|
||||
WebSocket API for streaming audio transcription.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
"""Gradium real-time speech-to-text service.
|
||||
|
||||
Provides real-time speech transcription using Gradium's WebSocket API.
|
||||
Supports both interim and final transcriptions with configurable parameters
|
||||
for audio processing and connection management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
json_config: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._json_config = json_config
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
self._audio_buffer = bytearray()
|
||||
self._chunk_size_ms = 80
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True if metrics generation is supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Start frame to begin processing.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: End frame to stop processing.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Cancel frame to abort processing.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text conversion.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to process.
|
||||
|
||||
Yields:
|
||||
None (processing handled via WebSocket messages).
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
chunk = base64.b64encode(chunk).decode("utf-8")
|
||||
msg = {"type": "audio", "audio": chunk}
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
yield None
|
||||
|
||||
@traced_stt
|
||||
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
|
||||
"""Record transcription event for tracing."""
|
||||
pass
|
||||
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
ws_url = self._api_endpoint_base_url
|
||||
headers = {
|
||||
"x-api-key": self._api_key,
|
||||
"x-api-source": "pipecat",
|
||||
}
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
additional_headers=headers,
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"input_format": "pcm",
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Gradium STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
await self._process_messages()
|
||||
logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _process_response(self, msg):
|
||||
type_ = msg.get("type", "")
|
||||
if type_ == "text":
|
||||
await self._handle_text(msg["text"])
|
||||
elif type_ == "end_of_stream":
|
||||
await self._handle_end_of_stream()
|
||||
elif type_ == "error":
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
|
||||
async def _handle_end_of_stream(self):
|
||||
"""Handle termination message."""
|
||||
logger.debug("Received end_of_stream message from server")
|
||||
|
||||
async def _handle_text(self, text: str):
|
||||
"""Handle transcription results."""
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
315
src/pipecat/services/gradium/tts.py
Normal file
315
src/pipecat/services/gradium/tts.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
|
||||
"""Gradium Text-to-Speech service implementation."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets import ConnectionClosedOK
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Gradium, you need to `pip install pipecat-ai[gradium]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class GradiumTTSService(InterruptibleWordTTSService):
|
||||
"""Text-to-Speech service using Gradium's websocket API."""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Gradium TTS service.
|
||||
|
||||
Parameters:
|
||||
temp: Temperature to be used for generation, defaults to 0.6.
|
||||
"""
|
||||
|
||||
temp: Optional[float] = 0.6
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "YTpq7expH9539ERJ",
|
||||
url: str = "wss://eu.api.gradium.ai/api/speech/tts",
|
||||
model: str = "default",
|
||||
json_config: Optional[str] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
voice_id: the voice identifier.
|
||||
url: Gradium websocket API endpoint.
|
||||
model: Model ID to use for synthesis.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
params: Additional configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
params = params or GradiumTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._json_config = json_config
|
||||
self._model = model
|
||||
self._settings = {
|
||||
"voice_id": voice_id,
|
||||
"model_name": model,
|
||||
"output_format": "pcm",
|
||||
}
|
||||
|
||||
# State tracking
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Gradium service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Update the TTS model.
|
||||
|
||||
Args:
|
||||
model: The model name to use for synthesis.
|
||||
"""
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
self._settings["voice_id"] = self._voice_id
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Gradium API."""
|
||||
return {"text": text, "type": "text"}
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish websocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close connection.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel current operation and clean up.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish websocket connection and start receive task."""
|
||||
logger.debug(f"{self}: connecting")
|
||||
|
||||
# If the server disconnected, cancel the receive-task so that it can be reset below.
|
||||
if self._websocket is None or self._websocket.state is not State.OPEN:
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
logger.debug(f"{self}: setting receive task")
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
logger.debug(f"{self}: disconnecting")
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Gradium websocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
|
||||
self._websocket = await websocket_connect(self._url, additional_headers=headers)
|
||||
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"output_format": "pcm",
|
||||
"voice_id": self._voice_id,
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close websocket connection and reset state."""
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
# TODO(laurent): This should not be necessary as it should happen when
|
||||
# receiving the messages but this does not seem to always be the case
|
||||
# and that may lead to a busy polling loop.
|
||||
if self._websocket and self._websocket.state is State.CLOSED:
|
||||
raise ConnectionClosedOK(None, None)
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
if msg["type"] == "audio":
|
||||
# Process audio chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
elif msg["type"] == "text":
|
||||
await self.add_word_timestamps([(msg["text"], msg["start_s"])])
|
||||
elif msg["type"] == "end_of_stream":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
|
||||
elif msg["type"] == "error":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Gradium's streaming API.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
_state = self._websocket.state if self._websocket is not None else None
|
||||
logger.debug(f"{self}: Generating TTS [{text}] {_state}")
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
self._websocket = None
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
yield TTSStartedFrame()
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -8,10 +8,14 @@ import base64
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat import __version__
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
@@ -26,11 +30,7 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from hume import AsyncHumeClient
|
||||
from hume.tts import (
|
||||
FormatPcm,
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts import FormatPcm, PostedUtterance, PostedUtteranceVoiceWithId
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -40,6 +40,12 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
# Tracking headers for Hume API requests
|
||||
DEFAULT_HEADERS = {
|
||||
"X-Hume-Client-Name": "pipecat",
|
||||
"X-Hume-Client-Version": __version__,
|
||||
}
|
||||
|
||||
|
||||
class HumeTTSService(WordTTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
@@ -104,7 +110,11 @@ class HumeTTSService(WordTTSService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
# Create a custom httpx.AsyncClient with tracking headers
|
||||
# Headers are included in all requests made by the Hume SDK
|
||||
self._http_client = httpx.AsyncClient(headers=DEFAULT_HEADERS)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key, httpx_client=self._http_client)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
|
||||
# Store voice in the base class (mirrors other services)
|
||||
@@ -138,6 +148,26 @@ class HumeTTSService(WordTTSService):
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
"""Stop the service and cleanup resources.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if hasattr(self, "_http_client") and self._http_client:
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
|
||||
"""Cancel the service and cleanup resources.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
if hasattr(self, "_http_client") and self._http_client:
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
|
||||
@@ -173,16 +173,17 @@ class LLMService(AIService):
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: bool = False
|
||||
self._skip_tts: Optional[bool] = None
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -293,7 +294,8 @@ class LLMService(AIService):
|
||||
direction: The direction of frame pushing.
|
||||
"""
|
||||
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
frame.skip_tts = self._skip_tts
|
||||
if self._skip_tts is not None:
|
||||
frame.skip_tts = self._skip_tts
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@@ -435,6 +437,7 @@ class LLMService(AIService):
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
|
||||
runner_items = []
|
||||
for function_call in function_calls:
|
||||
if function_call.function_name in self._functions.keys():
|
||||
item = self._functions[function_call.function_name]
|
||||
@@ -446,28 +449,20 @@ class LLMService(AIService):
|
||||
)
|
||||
continue
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
runner_items.append(
|
||||
FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
)
|
||||
|
||||
if self._run_in_parallel:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
if self._run_in_parallel:
|
||||
await self._run_parallel_function_calls(runner_items)
|
||||
else:
|
||||
await self._run_sequential_function_calls(runner_items)
|
||||
|
||||
async def request_image_frame(
|
||||
self,
|
||||
@@ -540,6 +535,27 @@ class LLMService(AIService):
|
||||
await task
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
tasks = []
|
||||
for runner_item in runner_items:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
tasks.append(task)
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
|
||||
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
@@ -623,20 +639,19 @@ class LLMService(AIService):
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
await self.cancel_task(task)
|
||||
if task:
|
||||
# We remove the callback because we are going to cancel the
|
||||
# task next, otherwise we will be removing it from the set
|
||||
# while we are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
await self.cancel_task(task)
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -46,17 +46,24 @@ class MCPClient(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
tools_filter: Optional[List[str]] = None,
|
||||
tools_output_filters: Optional[Dict[str, Callable[[Any], Any]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
Args:
|
||||
server_params: Server connection parameters (stdio or SSE).
|
||||
tools_filter: Optional list of tool names to register. If None, all tools are registered.
|
||||
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
|
||||
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
|
||||
**kwargs: Additional arguments passed to the parent BaseObject.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._server_params = server_params
|
||||
self._session = ClientSession
|
||||
self._tools_filter = tools_filter
|
||||
self._tools_output_filters = tools_output_filters or {}
|
||||
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
self._client = stdio_client
|
||||
@@ -264,13 +271,26 @@ class MCPClient(BaseObject):
|
||||
else:
|
||||
# logger.debug(f"Non-text result content: '{content}'")
|
||||
pass
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
logger.error(f"Error getting content from {function_name} results.")
|
||||
|
||||
final_response = response if len(response) else "Sorry, could not call the mcp tool"
|
||||
await result_callback(final_response)
|
||||
# Apply output filter if configured for this tool
|
||||
if function_name in self._tools_output_filters:
|
||||
try:
|
||||
response = self._tools_output_filters[function_name](response)
|
||||
logger.debug(f"Final response (after filter): {response}")
|
||||
|
||||
except Exception:
|
||||
logger.error(f"Error applying output filter for {function_name}")
|
||||
response = ""
|
||||
|
||||
if response and len(response) and isinstance(response, str):
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
response = "Sorry, could not call the mcp tool"
|
||||
|
||||
await result_callback(response)
|
||||
|
||||
async def _list_tools_helper(self, session):
|
||||
available_tools = await session.list_tools()
|
||||
@@ -283,6 +303,12 @@ class MCPClient(BaseObject):
|
||||
|
||||
for tool in available_tools.tools:
|
||||
tool_name = tool.name
|
||||
|
||||
# Apply tools filter if configured
|
||||
if self._tools_filter and tool_name not in self._tools_filter:
|
||||
logger.debug(f"Skipping tool '{tool_name}' - not in allowed tools list")
|
||||
continue
|
||||
|
||||
logger.debug(f"Processing tool: {tool_name}")
|
||||
logger.debug(f"Tool description: {tool.description}")
|
||||
|
||||
|
||||
@@ -8,98 +8,23 @@
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please NvidiaLLMService from
|
||||
pipecat.services.nvidia.llm instead.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
import warnings
|
||||
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
|
||||
class NimLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"NimLLMService from pipecat.services.nim.llm is deprecated. "
|
||||
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NimLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
NimLLMService = NvidiaLLMService
|
||||
|
||||
0
src/pipecat/services/nvidia/__init__.py
Normal file
0
src/pipecat/services/nvidia/__init__.py
Normal file
105
src/pipecat/services/nvidia/llm.py
Normal file
105
src/pipecat/services/nvidia/llm.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA NIM API service implementation.
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class NvidiaLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NvidiaLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
663
src/pipecat/services/nvidia/stt.py
Normal file
663
src/pipecat/services/nvidia/stt.py
Normal file
@@ -0,0 +1,663 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to NVIDIA Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class NvidiaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for NVIDIA Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize NVIDIA Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use NVIDIA Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create NVIDIA Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to NVIDIA Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_nvidia_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in NVIDIA Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
187
src/pipecat/services/nvidia/tts.py
Normal file
187
src/pipecat/services/nvidia/tts.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva text-to-speech service implementation.
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
NVIDIA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
@@ -133,6 +133,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
self._retry_timeout_secs = retry_timeout_secs
|
||||
self._retry_on_timeout = retry_on_timeout
|
||||
self.set_model_name(model)
|
||||
self._full_model_name: str = ""
|
||||
self._client = self.create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
@@ -185,6 +186,22 @@ class BaseOpenAILLMService(LLMService):
|
||||
"""
|
||||
return True
|
||||
|
||||
def set_full_model_name(self, full_model_name: str):
|
||||
"""Set the full AI model name.
|
||||
|
||||
Args:
|
||||
full_model_name: The full name of the AI model to use.
|
||||
"""
|
||||
self._full_model_name = full_model_name
|
||||
|
||||
def get_full_model_name(self):
|
||||
"""Get the current full model name.
|
||||
|
||||
Returns:
|
||||
The full name of the AI model being used.
|
||||
"""
|
||||
return self._full_model_name
|
||||
|
||||
async def get_chat_completions(
|
||||
self, params_from_context: OpenAILLMInvocationParams
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -360,6 +377,9 @@ class BaseOpenAILLMService(LLMService):
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.model and self.get_full_model_name() != chunk.model:
|
||||
self.set_full_model_name(chunk.model)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
|
||||
@@ -4,707 +4,32 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please NvidiaSTTService from
|
||||
pipecat.services.nvidia.stt instead.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
import warnings
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
from pipecat.services.nvidia.stt import (
|
||||
NvidiaSegmentedSTTService,
|
||||
NvidiaSTTService,
|
||||
language_to_nvidia_riva_language,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class RivaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class RivaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized RivaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
|
||||
class ParakeetSTTService(RivaSTTService):
|
||||
"""Deprecated speech-to-text service using NVIDIA Parakeet models.
|
||||
|
||||
.. deprecated:: 0.0.66
|
||||
This class is deprecated. Use `RivaSTTService` instead for equivalent functionality
|
||||
with Parakeet models by specifying the appropriate model_function_map.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[RivaSTTService.InputParams] = None, # Use parent class's type
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Parakeet STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for Riva.
|
||||
**kwargs: Additional arguments passed to RivaSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
server=server,
|
||||
model_function_map=model_function_map,
|
||||
sample_rate=sample_rate,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"RivaSTTService and ParakeetSTTService "
|
||||
"from pipecat.services.riva.stt is deprecated. "
|
||||
"Please use NvidiaSTTService from pipecat.services.nvidia.stt instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
RivaSTTService = NvidiaSTTService
|
||||
language_to_riva_language = language_to_nvidia_riva_language
|
||||
RivaSegmentedSTTService = NvidiaSegmentedSTTService
|
||||
ParakeetSTTService = NvidiaSTTService
|
||||
|
||||
@@ -8,231 +8,26 @@
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please NvidiaTTSService from
|
||||
pipecat.services.nvidia.tts instead.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
from pipecat.services.nvidia.tts import NVIDIA_TTS_TIMEOUT_SECS, NvidiaTTSService
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"FastPitchTTSService and RivaTTSService "
|
||||
"from pipecat.services.nim.llm are deprecated. "
|
||||
"Please use NvidiaLLMService from pipecat.services.nvidia.tts instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
RIVA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class RivaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class FastPitchTTSService(RivaTTSService):
|
||||
"""Deprecated FastPitch TTS service.
|
||||
|
||||
.. deprecated:: 0.0.66
|
||||
This class is deprecated. Use RivaTTSService instead for new implementations.
|
||||
Provides backward compatibility for existing FastPitch TTS integrations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
"model_name": "fastpitch-hifigan-tts",
|
||||
},
|
||||
params: Optional[RivaTTSService.InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the deprecated FastPitch TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to Female-1 voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for FastPitch model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent RivaTTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
server=server,
|
||||
voice_id=voice_id,
|
||||
sample_rate=sample_rate,
|
||||
model_function_map=model_function_map,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
RivaTTSService = NvidiaTTSService
|
||||
FastPitchTTSService = NvidiaTTSService
|
||||
RIVA_TTS_TIMEOUT_SECS = NVIDIA_TTS_TIMEOUT_SECS
|
||||
|
||||
@@ -514,9 +514,11 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame and flush audio if it's the end of a full response."""
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
return await super().process_frame(frame, direction)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
|
||||
@@ -425,16 +425,13 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
pending_aggregation = self._text_aggregator.text
|
||||
# Flush any remaining text (including text waiting for lookahead)
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
|
||||
# Reset aggregator state
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
|
||||
if pending_aggregation.text:
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(pending_aggregation.text, pending_aggregation.type)
|
||||
)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -539,17 +536,20 @@ class TTSService(AIService):
|
||||
text = frame.text
|
||||
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
aggregated_by = "token"
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
else:
|
||||
aggregate = await self._text_aggregator.aggregate(frame.text)
|
||||
if aggregate:
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Awaitable, Callable, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
from websockets.protocol import State
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame
|
||||
@@ -137,6 +137,10 @@ class WebsocketService(ABC):
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
break
|
||||
except ConnectionClosedError as e:
|
||||
# Error closure, don't retry
|
||||
logger.warning(f"{self} connection closed, but with an error: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
|
||||
@@ -6,31 +6,14 @@
|
||||
|
||||
"""Base notifier interface for Pipecat."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
class BaseNotifier(ABC):
|
||||
"""Abstract base class for notification mechanisms.
|
||||
|
||||
Provides a standard interface for implementing notification and waiting
|
||||
patterns used for event coordination and signaling between components
|
||||
in the Pipecat framework.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify(self):
|
||||
"""Send a notification signal.
|
||||
|
||||
Implementations should trigger any waiting coroutines or processes
|
||||
that are blocked on this notifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def wait(self):
|
||||
"""Wait for a notification signal.
|
||||
|
||||
Implementations should block until a notification is received
|
||||
from the corresponding notify() call.
|
||||
"""
|
||||
pass
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -6,40 +6,14 @@
|
||||
|
||||
"""Event-based notifier implementation using asyncio Event primitives."""
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class EventNotifier(BaseNotifier):
|
||||
"""Event-based notifier using asyncio.Event for task synchronization.
|
||||
|
||||
Provides a simple notification mechanism where one task can signal
|
||||
an event and other tasks can wait for that event to occur. The event
|
||||
is automatically cleared after each wait operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the event notifier.
|
||||
|
||||
Creates an internal asyncio.Event for managing notifications.
|
||||
"""
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def notify(self):
|
||||
"""Signal the event to notify waiting tasks.
|
||||
|
||||
Sets the internal event, causing any tasks waiting on this
|
||||
notifier to be awakened.
|
||||
"""
|
||||
self._event.set()
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the event to be signaled.
|
||||
|
||||
Blocks until another task calls notify(). Automatically clears
|
||||
the event after being awakened so subsequent calls will wait
|
||||
for the next notification.
|
||||
"""
|
||||
await self._event.wait()
|
||||
self._event.clear()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ comprehensive monitoring and cleanup capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Dict, Optional, Sequence
|
||||
@@ -162,7 +163,9 @@ class TaskManager(BaseTaskManager):
|
||||
# Re-raise the exception to ensure the task is cancelled.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{name}: unexpected exception: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.error(f"{name} unexpected exception ({last.filename}:{last.lineno}): {e}")
|
||||
|
||||
if not self._params:
|
||||
raise Exception("TaskManager is not setup: unable to get event loop")
|
||||
@@ -197,9 +200,17 @@ class TaskManager(BaseTaskManager):
|
||||
# Here are sure the task is cancelled properly.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{name}: unexpected exception while cancelling task: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.error(
|
||||
f"{name} unexpected exception while cancelling task ({last.filename}:{last.lineno}): {e}"
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.critical(f"{name}: fatal base exception while cancelling task: {e}")
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
last = tb[-1]
|
||||
logger.critical(
|
||||
f"{name} fatal base exception while cancelling task ({last.filename}:{last.lineno}): {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def current_tasks(self) -> Sequence[asyncio.Task]:
|
||||
|
||||
@@ -203,7 +203,7 @@ def parse_start_end_tags(
|
||||
class TextPartForConcatenation:
|
||||
"""Class representing a part of text for concatenation with concatenate_aggregated_text.
|
||||
|
||||
Attributes:
|
||||
Parameters:
|
||||
text: The text content.
|
||||
includes_inter_part_spaces: Whether any necessary inter-frame
|
||||
(leading/trailing) spaces are already included in the text.
|
||||
|
||||
0
src/pipecat/utils/sync/__init__.py
Normal file
0
src/pipecat/utils/sync/__init__.py
Normal file
36
src/pipecat/utils/sync/base_notifier.py
Normal file
36
src/pipecat/utils/sync/base_notifier.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base notifier interface for Pipecat."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseNotifier(ABC):
|
||||
"""Abstract base class for notification mechanisms.
|
||||
|
||||
Provides a standard interface for implementing notification and waiting
|
||||
patterns used for event coordination and signaling between components
|
||||
in the Pipecat framework.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify(self):
|
||||
"""Send a notification signal.
|
||||
|
||||
Implementations should trigger any waiting coroutines or processes
|
||||
that are blocked on this notifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def wait(self):
|
||||
"""Wait for a notification signal.
|
||||
|
||||
Implementations should block until a notification is received
|
||||
from the corresponding notify() call.
|
||||
"""
|
||||
pass
|
||||
45
src/pipecat/utils/sync/event_notifier.py
Normal file
45
src/pipecat/utils/sync/event_notifier.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event-based notifier implementation using asyncio Event primitives."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class EventNotifier(BaseNotifier):
|
||||
"""Event-based notifier using asyncio.Event for task synchronization.
|
||||
|
||||
Provides a simple notification mechanism where one task can signal
|
||||
an event and other tasks can wait for that event to occur. The event
|
||||
is automatically cleared after each wait operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the event notifier.
|
||||
|
||||
Creates an internal asyncio.Event for managing notifications.
|
||||
"""
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def notify(self):
|
||||
"""Signal the event to notify waiting tasks.
|
||||
|
||||
Sets the internal event, causing any tasks waiting on this
|
||||
notifier to be awakened.
|
||||
"""
|
||||
self._event.set()
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the event to be signaled.
|
||||
|
||||
Blocks until another task calls notify(). Automatically clears
|
||||
the event after being awakened so subsequent calls will wait
|
||||
for the next notification.
|
||||
"""
|
||||
await self._event.wait()
|
||||
self._event.clear()
|
||||
@@ -14,7 +14,7 @@ aggregated text should be sent for speech synthesis.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
@@ -80,33 +80,43 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate the specified text and yield completed aggregations.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
This method processes the input text character-by-character internally
|
||||
and yields Aggregation objects as they complete.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
- How to combine new text with existing accumulated text
|
||||
- How to process text character-by-character
|
||||
- When to consider the aggregated text ready for processing
|
||||
- What criteria determine text completion (e.g., sentence boundaries)
|
||||
- When a completion occurs, the method should return an Aggregation object
|
||||
containing the aggregated text and its type. The text should be stripped
|
||||
of leading/trailing whitespace so that consumers can rely on a consistent
|
||||
format.
|
||||
- When a completion occurs, yield an Aggregation object containing the
|
||||
aggregated text (stripped of leading/trailing whitespace) and its type
|
||||
|
||||
Args:
|
||||
text: The text to be aggregated.
|
||||
|
||||
Yields:
|
||||
Aggregation objects as they complete. Each Aggregation consists of
|
||||
the aggregated text (stripped of leading/trailing whitespace) and
|
||||
a string indicating the type of aggregation (e.g., 'sentence', 'word',
|
||||
'token', 'my_custom_aggregation').
|
||||
"""
|
||||
pass
|
||||
# Make this a generator to satisfy type checker
|
||||
yield # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self) -> Optional[Aggregation]:
|
||||
"""Flush any pending aggregation.
|
||||
|
||||
This method is called at the end of a stream (e.g., when receiving
|
||||
LLMFullResponseEndFrame) to return any text that was buffered.
|
||||
|
||||
Returns:
|
||||
An Aggregation object if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready. If an Aggregation
|
||||
object is returned, it should consist of the updated aggregated text,
|
||||
stripped of leading/trailing whitespace, and a string indicating the
|
||||
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
An Aggregation object if there is pending text, or None if there
|
||||
is no pending text.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ support for custom handlers and configurable actions for when a pattern is found
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import AsyncIterator, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
class MatchAction(Enum):
|
||||
@@ -26,15 +26,15 @@ class MatchAction(Enum):
|
||||
|
||||
Parameters:
|
||||
REMOVE: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
KEEP: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
AGGREGATE: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
@@ -72,7 +72,7 @@ class PatternMatch(Aggregation):
|
||||
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
class PatternPairAggregator(SimpleTextAggregator):
|
||||
"""Aggregator that identifies and processes content between pattern pairs.
|
||||
|
||||
This aggregator buffers text until it can identify complete pattern pairs
|
||||
@@ -97,9 +97,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
Text buffering and pattern detection will begin when text is aggregated.
|
||||
"""
|
||||
self._text = ""
|
||||
super().__init__()
|
||||
self._patterns = {}
|
||||
self._handlers = {}
|
||||
self._last_processed_position = 0 # Track where we last checked for complete patterns
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
@@ -132,17 +133,15 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
Args:
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
action: What to do when a complete pattern is matched.
|
||||
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as normal text. This allows you to register handlers for the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
@@ -218,14 +217,18 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers[type] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process all complete pattern pairs in the text.
|
||||
async def _process_complete_patterns(
|
||||
self, text: str, last_processed_position: int = 0
|
||||
) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process newly complete pattern pairs in the text.
|
||||
|
||||
Searches for all complete pattern pairs in the text, calls the
|
||||
appropriate handlers, and optionally removes the matches.
|
||||
Searches for pattern pairs that have been completed since last_processed_position,
|
||||
calls the appropriate handlers, and optionally removes the matches.
|
||||
|
||||
Args:
|
||||
text: The text to process for pattern matches.
|
||||
last_processed_position: The position in text that was already processed.
|
||||
Only patterns that end at or after this position will be processed.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_matches, processed_text) where:
|
||||
@@ -259,17 +262,23 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
content=content.strip(), type=type, full_match=full_match
|
||||
)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if type in self._handlers:
|
||||
# Check if this pattern was already processed
|
||||
already_processed = match.end() <= last_processed_position
|
||||
|
||||
# Only call handler for newly completed patterns
|
||||
if not already_processed and type in self._handlers:
|
||||
try:
|
||||
await self._handlers[type](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
# Handle pattern based on action
|
||||
if action == MatchAction.REMOVE:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
# Remove patterns are only removed once (when newly completed)
|
||||
if not already_processed:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
else:
|
||||
# KEEP/AGGREGATE patterns stay in all_matches
|
||||
all_matches.append(pattern_match)
|
||||
|
||||
return all_matches, processed_text
|
||||
@@ -305,76 +314,84 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
return None
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
async def aggregate(self, text: str) -> AsyncIterator[PatternMatch]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
pairs, and returns processed text up to sentence boundaries if possible.
|
||||
If there are incomplete patterns (start without matching end), it will
|
||||
continue buffering text.
|
||||
Processes the input text character-by-character, handles pattern pairs,
|
||||
and uses the parent's lookahead logic for sentence detection when no
|
||||
patterns are active.
|
||||
|
||||
Args:
|
||||
text: New text to add to the buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Returns:
|
||||
Processed text up to a sentence boundary, or None if more
|
||||
text is needed to form a complete sentence or pattern.
|
||||
Yields:
|
||||
PatternMatch objects as patterns complete or sentences are detected.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Process any complete patterns in the buffer
|
||||
patterns, processed_text = await self._process_complete_patterns(self._text)
|
||||
# Process any newly complete patterns in the buffer
|
||||
# Only patterns that complete after _last_processed_position will trigger handlers
|
||||
patterns, processed_text = await self._process_complete_patterns(
|
||||
self._text, self._last_processed_position
|
||||
)
|
||||
|
||||
self._text = processed_text
|
||||
# Update the last processed position to prevent re-processing patterns
|
||||
# This tracks where in the buffer we've already called handlers, so we
|
||||
# only trigger handlers once when a pattern completes
|
||||
self._last_processed_position = len(self._text)
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
self._text = processed_text
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
yield patterns[0]
|
||||
continue
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, continue
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
continue
|
||||
# For AGGREGATE patterns: yield any text before the pattern starts
|
||||
# This ensures text doesn't get stuck in the buffer waiting for sentence
|
||||
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
yield PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
continue
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
if eos_marker:
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
yield PatternMatch(
|
||||
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
|
||||
)
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the buffer.
|
||||
"""Handle interruptions by clearing the buffer and pattern state.
|
||||
|
||||
Called when an interruption occurs in the processing pipeline,
|
||||
to reset the state and discard any partially aggregated text.
|
||||
"""
|
||||
self._text = ""
|
||||
await super().handle_interruption()
|
||||
self._last_processed_position = 0
|
||||
# Pattern and handler state persists across interruptions
|
||||
|
||||
async def reset(self):
|
||||
"""Clear the internally aggregated text.
|
||||
@@ -382,4 +399,6 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
Resets the aggregator to its initial state, discarding any
|
||||
buffered text and clearing pattern tracking state.
|
||||
"""
|
||||
self._text = ""
|
||||
await super().reset()
|
||||
self._last_processed_position = 0
|
||||
# Pattern and handler state persists across resets
|
||||
|
||||
@@ -11,9 +11,9 @@ until it finds an end-of-sentence marker, making it suitable for basic TTS
|
||||
text processing scenarios.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.string import SENTENCE_ENDING_PUNCTUATION, match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
Creates an empty text buffer ready to begin accumulating text tokens.
|
||||
"""
|
||||
self._text = ""
|
||||
self._needs_lookahead: bool = False
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
@@ -41,30 +42,87 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text and yield completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
When a sentence boundary is found, returns the completed sentence and
|
||||
removes it from the buffer.
|
||||
Processes the input text character-by-character. When sentence-ending
|
||||
punctuation is detected, it waits for non-whitespace lookahead before
|
||||
calling NLTK. This prevents false positives like "$29." being detected
|
||||
as a sentence when it's actually "$29.95".
|
||||
|
||||
Args:
|
||||
text: New text to add to the aggregation buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Complete sentences as Aggregation objects.
|
||||
"""
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Check for sentence with lookahead
|
||||
result = await self._check_sentence_with_lookahead(char)
|
||||
if result:
|
||||
yield result
|
||||
|
||||
async def _check_sentence_with_lookahead(self, char: str) -> Optional[Aggregation]:
|
||||
"""Check for sentence boundaries using lookahead logic.
|
||||
|
||||
This method implements the core sentence detection logic with lookahead.
|
||||
When sentence-ending punctuation is detected, it waits for the next
|
||||
non-whitespace character before calling NLTK. This disambiguates cases
|
||||
like "$29." (not a sentence) vs "$29. Next" (sentence ends at period).
|
||||
Whitespace alone is not meaningful lookahead since it appears in both
|
||||
cases. Instead, the first non-whitespace character after the punctuation
|
||||
is used to confirm the sentence boundary.
|
||||
|
||||
Subclasses can call this via super() to reuse the lookahead behavior
|
||||
while adding their own logic (e.g., tag handling, pattern matching).
|
||||
|
||||
Args:
|
||||
char: The most recently added character (used for lookahead check).
|
||||
|
||||
Returns:
|
||||
A complete sentence if an end-of-sentence marker is found,
|
||||
or None if more text is needed to complete a sentence.
|
||||
Aggregation if sentence found, None otherwise.
|
||||
"""
|
||||
result: Optional[str] = None
|
||||
# If we need lookahead, check if we now have non-whitespace
|
||||
if self._needs_lookahead:
|
||||
# Check if the new character is non-whitespace
|
||||
if char.strip():
|
||||
# We have meaningful lookahead, call NLTK
|
||||
self._needs_lookahead = False
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
|
||||
self._text += text
|
||||
if eos_marker:
|
||||
# NLTK confirmed a sentence - return it
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result, type=AggregationType.SENTENCE)
|
||||
# No sentence found - keep accumulating
|
||||
return None
|
||||
# Still whitespace, keep waiting
|
||||
return None
|
||||
|
||||
eos_end_marker = match_endofsentence(self._text)
|
||||
if eos_end_marker:
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
# Check if we just added sentence-ending punctuation
|
||||
if self._text and self._text[-1] in SENTENCE_ENDING_PUNCTUATION:
|
||||
# Mark that we need lookahead (don't call NLTK yet)
|
||||
self._needs_lookahead = True
|
||||
|
||||
if result:
|
||||
return None
|
||||
|
||||
async def flush(self) -> Optional[Aggregation]:
|
||||
"""Flush any remaining text in the buffer.
|
||||
|
||||
Returns any text remaining in the buffer. This is called at the end
|
||||
of a stream to ensure all text is processed.
|
||||
|
||||
Returns:
|
||||
Any remaining text as a sentence, or None if buffer is empty.
|
||||
"""
|
||||
if self._text:
|
||||
# Return whatever we have in the buffer
|
||||
result = self._text
|
||||
await self.reset()
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return None
|
||||
|
||||
@@ -75,6 +133,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
discarding any partially accumulated text.
|
||||
"""
|
||||
self._text = ""
|
||||
self._needs_lookahead = False
|
||||
|
||||
async def reset(self):
|
||||
"""Clear the internally aggregated text.
|
||||
@@ -83,3 +142,4 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
any accumulated text content.
|
||||
"""
|
||||
self._text = ""
|
||||
self._needs_lookahead = False
|
||||
|
||||
@@ -11,13 +11,14 @@ between specified start/end tag pairs, ensuring that tagged content is processed
|
||||
as a unit regardless of internal punctuation.
|
||||
"""
|
||||
|
||||
from typing import Optional, Sequence
|
||||
from typing import AsyncIterator, Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.string import StartEndTags, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
class SkipTagsAggregator(BaseTextAggregator):
|
||||
class SkipTagsAggregator(SimpleTextAggregator):
|
||||
"""Aggregator that prevents end of sentence matching between start/end tags.
|
||||
|
||||
This aggregator buffers text until it finds an end of sentence or a start
|
||||
@@ -37,67 +38,59 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
tags: Sequence of StartEndTags objects defining the tag pairs
|
||||
that should prevent sentence boundary detection.
|
||||
"""
|
||||
self._text = ""
|
||||
super().__init__()
|
||||
self._tags = tags
|
||||
self._current_tag: Optional[StartEndTags] = None
|
||||
self._current_tag_index: int = 0
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
pattern pairs, and returns processed text up to sentence boundaries if
|
||||
possible. If there are incomplete patterns (start without matching
|
||||
end), it will continue buffering text.
|
||||
Processes the input text character-by-character, updates tag state, and
|
||||
uses the parent's lookahead logic for sentence detection when not
|
||||
inside tags.
|
||||
|
||||
Args:
|
||||
text: New text to add to the buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Returns:
|
||||
An Aggregation object containing text up to a sentence boundary and
|
||||
marked as SENTENCE type or None if more text is needed to complete a
|
||||
sentence or close tags.
|
||||
Yields:
|
||||
Aggregation objects containing text up to a sentence boundary,
|
||||
marked as SENTENCE type.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
# Update tag state
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
if not self._current_tag:
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
if eos_marker:
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
# If inside tags, don't check for sentences
|
||||
if self._current_tag:
|
||||
continue
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
# Otherwise, use parent's lookahead logic for sentence detection
|
||||
result = await super()._check_sentence_with_lookahead(char)
|
||||
if result:
|
||||
yield result
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the buffer.
|
||||
"""Handle interruptions by clearing the buffer and tag state.
|
||||
|
||||
Called when an interruption occurs in the processing pipeline,
|
||||
to reset the state and discard any partially aggregated text.
|
||||
"""
|
||||
self._text = ""
|
||||
await super().handle_interruption()
|
||||
self._current_tag = None
|
||||
self._current_tag_index = 0
|
||||
|
||||
async def reset(self):
|
||||
"""Clear the internally aggregated text.
|
||||
"""Clear the internally aggregated text and tag state.
|
||||
|
||||
Resets the aggregator to its initial state, discarding any
|
||||
buffered text.
|
||||
"""
|
||||
self._text = ""
|
||||
await super().reset()
|
||||
self._current_tag = None
|
||||
self._current_tag_index = 0
|
||||
|
||||
@@ -483,7 +483,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
# Add all available attributes to the span
|
||||
attribute_kwargs = {
|
||||
"service_name": service_class_name,
|
||||
"model": getattr(self, "model_name", "unknown"),
|
||||
"model": getattr(
|
||||
self, getattr(self, "_full_model_name", "model_name"), "unknown"
|
||||
),
|
||||
"stream": True, # Most LLM services use streaming
|
||||
"parameters": params,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.extensions.ivr.ivr_navigator import IVRProcessor
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMTextFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -334,10 +335,12 @@ class TestIVRNavigation(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
frames_to_send = [
|
||||
LLMTextFrame(text="Hello, I'm trying to reach billing."),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMTextFrame, # Should pass through unchanged
|
||||
LLMFullResponseEndFrame,
|
||||
]
|
||||
|
||||
expected_up_frames = [
|
||||
|
||||
@@ -38,14 +38,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
text = "Hello <test>pattern content</test>!"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.test_handler.assert_called_once()
|
||||
@@ -55,28 +49,37 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
# No results yet (waiting for lookahead after "!")
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Next sentence should be processed separately. Spaces around the sentence
|
||||
# should be stripped in the returned Aggregation.
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
# Next sentence should provide the lookahead and trigger the previous sentence
|
||||
async for result in self.aggregator.aggregate(" This is another sentence."):
|
||||
results.append(result)
|
||||
|
||||
# First result should be "Hello !" triggered by the space lookahead
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hello !")
|
||||
self.assertEqual(results[0].type, "sentence")
|
||||
|
||||
# Now flush to get the remaining sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
text = "Here is code <code>pattern content</code> This is another sentence."
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# First result should be "Here is code" when pattern starts
|
||||
self.assertEqual(results[0].text, "Here is code")
|
||||
self.assertEqual(results[0].type, "sentence")
|
||||
|
||||
# Second result should be the code pattern content
|
||||
self.assertEqual(results[1].text, "pattern content")
|
||||
self.assertEqual(results[1].type, "code_pattern")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
@@ -85,11 +88,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
# Last sentence needs flush (waiting for lookahead after ".")
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
@@ -97,11 +98,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern content")
|
||||
|
||||
text = "Hello <test>pattern content"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
# No complete pattern yet, so nothing should be returned
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# The handler should not be called yet
|
||||
self.test_handler.assert_not_called()
|
||||
@@ -136,9 +136,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
self.aggregator.on_pattern_match("emphasis", emphasis_handler)
|
||||
|
||||
# Test with multiple patterns in one text block
|
||||
text = "Hello <voice>female</voice> I am <em>very</em> excited to meet you!"
|
||||
result = await self.aggregator.aggregate(text)
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Both handlers should be called with correct data
|
||||
voice_handler.assert_called_once()
|
||||
@@ -151,6 +150,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(emphasis_match.type, "emphasis")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
|
||||
# With lookahead, we need to flush to get the final sentence
|
||||
self.assertEqual(len(results), 0) # Waiting for lookahead after "!"
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
|
||||
@@ -158,9 +161,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_handle_interruption(self):
|
||||
# Start with incomplete pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
text = "Hello <test>pattern"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Simulate interruption
|
||||
await self.aggregator.handle_interruption()
|
||||
@@ -172,20 +175,18 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
async def test_pattern_across_sentences(self):
|
||||
# Test pattern that spans multiple sentences
|
||||
result = await self.aggregator.aggregate("Hello <test>This is sentence one.")
|
||||
|
||||
# First sentence contains start of pattern but no end, so no complete pattern yet
|
||||
self.assertIsNone(result)
|
||||
|
||||
# Add second part with pattern end
|
||||
result = await self.aggregator.aggregate(" This is sentence two.</test> Final sentence.")
|
||||
text = "Hello <test>This is sentence one. This is sentence two.</test> Final sentence."
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Handler should be called with entire content
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
||||
|
||||
# With lookahead, we need to flush to get the final sentence
|
||||
self.assertEqual(len(results), 0) # Waiting for lookahead after "."
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
|
||||
|
||||
@@ -14,22 +14,112 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator = SimpleTextAggregator()
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
text = "Hello "
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet
|
||||
assert len(results) == 0
|
||||
assert self.aggregator.text.text == "Hello"
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
text = "Hello Pipecat!"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after "!")
|
||||
assert len(results) == 0
|
||||
|
||||
# Flush to get the pending sentence
|
||||
aggregate = await self.aggregator.flush()
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
# Aggregators should strip leading/trailing spaces when returning text
|
||||
assert self.aggregator.text.text == "How are"
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == "How are you?"
|
||||
text = "Hello Pipecat! How are you?"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "H" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello Pipecat!"
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "How are you?"
|
||||
|
||||
async def test_lookahead_decimal_number(self):
|
||||
"""Test that $29.95 is not split at $29."""
|
||||
text = "Ask me for only $29.95/month."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after final ".")
|
||||
assert len(results) == 0
|
||||
|
||||
# Can use flush() to get the pending sentence at end of stream
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "Ask me for only $29.95/month."
|
||||
|
||||
async def test_lookahead_abbreviation(self):
|
||||
"""Test that Mr. Smith is not split at Mr."""
|
||||
text = "Hello Mr. Smith."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after final ".")
|
||||
assert len(results) == 0
|
||||
|
||||
# Can use flush() to get the pending sentence at end of stream
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "Hello Mr. Smith."
|
||||
|
||||
async def test_lookahead_actual_sentence_end(self):
|
||||
"""Test that a real sentence end is detected after lookahead."""
|
||||
text = "Hello world. Next sentence"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "N" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello world."
|
||||
|
||||
async def test_flush_pending_sentence(self):
|
||||
"""Test that flush() returns pending sentence waiting for lookahead."""
|
||||
text = "Hello world."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead)
|
||||
assert len(results) == 0
|
||||
|
||||
# Call flush to get it
|
||||
result = await self.aggregator.flush()
|
||||
assert result is not None
|
||||
assert result.text == "Hello world."
|
||||
# Flush again should return None
|
||||
assert await self.aggregator.flush() == None
|
||||
|
||||
async def test_flush_with_no_pending(self):
|
||||
"""Test that flush() returns any remaining text in buffer."""
|
||||
text = "Hello"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences
|
||||
assert len(results) == 0
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# flush() now returns any remaining text, not just pending lookahead
|
||||
assert result is not None
|
||||
assert result.text == "Hello"
|
||||
# Buffer should be empty after flush
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_flush_after_lookahead_confirmed(self):
|
||||
"""Test flush after lookahead has already confirmed sentence."""
|
||||
text = "Hello. W"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "W" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello."
|
||||
|
||||
# flush() returns any remaining text (the "W" in this case)
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "W"
|
||||
|
||||
@@ -17,7 +17,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# No tags involved, aggregate at end of sentence.
|
||||
result = await self.aggregator.aggregate("Hello Pipecat!")
|
||||
text = "Hello Pipecat!"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# Should still be waiting for lookahead after "!"
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "Hello Pipecat!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
@@ -26,7 +33,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, avoid aggregation during tags.
|
||||
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
|
||||
text = "My email is <spell>foo@pipecat.ai</spell>."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# Should still be waiting for lookahead after "."
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
@@ -34,25 +48,17 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_streaming_tags(self):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, stream small chunk of texts.
|
||||
result = await self.aggregator.aggregate("My email is <sp")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <sp")
|
||||
# Tags involved
|
||||
text = "My email is <spell>foo.bar@pipecat.ai</spell>."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
result = await self.aggregator.aggregate("ell>foo.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
|
||||
|
||||
result = await self.aggregator.aggregate("bar@pipecat.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
|
||||
|
||||
result = await self.aggregator.aggregate("ai</spe")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
# Should still be waiting for lookahead after "."
|
||||
self.assertEqual(len(results), 0)
|
||||
self.assertEqual(self.aggregator.text.text, text)
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
result = await self.aggregator.aggregate("ll>.")
|
||||
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, text)
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
34
uv.lock
generated
34
uv.lock
generated
@@ -36,12 +36,12 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "aic-sdk"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8ebc574fb4a4f23c07a3740b51992d7522976173d30b98/aic_sdk-1.1.0.tar.gz", hash = "sha256:04e08df695581c8cb4db8acca20e73815e9f449e7bd08e0162fd55518c727963", size = 34954, upload-time = "2025-11-11T20:45:24.25Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/ba/3ebe31b91e03d42437ec864e9d2af3a52b7ccc73a1a0c1026275956270b0/aic_sdk-1.2.0.tar.gz", hash = "sha256:eeda9a181c679f175dbe6f0efc0c67ec98ff3d84cfe01541fef7fa12ecd505ca", size = 35606, upload-time = "2025-11-20T14:42:14.333Z" }
|
||||
|
||||
[[package]]
|
||||
name = "aioboto3"
|
||||
@@ -4496,6 +4496,9 @@ google = [
|
||||
{ name = "google-genai" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
gradium = [
|
||||
{ name = "websockets" },
|
||||
]
|
||||
groq = [
|
||||
{ name = "groq" },
|
||||
]
|
||||
@@ -4564,6 +4567,9 @@ neuphonic = [
|
||||
noisereduce = [
|
||||
{ name = "noisereduce" },
|
||||
]
|
||||
nvidia = [
|
||||
{ name = "nvidia-riva-client" },
|
||||
]
|
||||
openai = [
|
||||
{ name = "websockets" },
|
||||
]
|
||||
@@ -4652,6 +4658,7 @@ dev = [
|
||||
{ name = "ruff" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "setuptools-scm" },
|
||||
{ name = "towncrier" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
@@ -4666,7 +4673,7 @@ docs = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
|
||||
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
|
||||
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.2.0" },
|
||||
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.5.0" },
|
||||
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
|
||||
{ name = "aiohttp", specifier = ">=3.11.12,<4" },
|
||||
@@ -4706,7 +4713,7 @@ requires-dist = [
|
||||
{ name = "noisereduce", marker = "extra == 'noisereduce'", specifier = "~=3.0.3" },
|
||||
{ name = "numba", specifier = "==0.61.2" },
|
||||
{ name = "numpy", specifier = ">=1.26.4,<3" },
|
||||
{ name = "nvidia-riva-client", marker = "extra == 'riva'", specifier = "~=2.21.1" },
|
||||
{ name = "nvidia-riva-client", marker = "extra == 'nvidia'", specifier = "~=2.21.1" },
|
||||
{ name = "onnxruntime", marker = "extra == 'local-smart-turn-v3'", specifier = ">=1.20.1,<2" },
|
||||
{ name = "onnxruntime", marker = "extra == 'silero'", specifier = ">=1.20.1,<2" },
|
||||
{ name = "openai", specifier = ">=1.74.0,<3" },
|
||||
@@ -4717,6 +4724,7 @@ requires-dist = [
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'tracing'", specifier = ">=1.33.0" },
|
||||
{ name = "ormsgpack", marker = "extra == 'fish'", specifier = "~=1.7.0" },
|
||||
{ name = "pillow", specifier = ">=11.1.0,<12" },
|
||||
{ name = "pipecat-ai", extras = ["nvidia"], marker = "extra == 'riva'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'assemblyai'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'asyncai'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'aws'" },
|
||||
@@ -4726,6 +4734,7 @@ requires-dist = [
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'fish'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'gladia'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'google'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'gradium'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'heygen'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'lmnt'" },
|
||||
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'neuphonic'" },
|
||||
@@ -4767,7 +4776,7 @@ requires-dist = [
|
||||
{ name = "wait-for2", marker = "python_full_version < '3.12'", specifier = ">=0.4.1" },
|
||||
{ name = "websockets", marker = "extra == 'websockets-base'", specifier = ">=13.1,<16.0" },
|
||||
]
|
||||
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "nim", "noisereduce", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
|
||||
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "gradium", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "noisereduce", "nvidia", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -4784,6 +4793,7 @@ dev = [
|
||||
{ name = "ruff", specifier = ">=0.12.11,<1" },
|
||||
{ name = "setuptools", specifier = "~=78.1.1" },
|
||||
{ name = "setuptools-scm", specifier = "~=8.3.1" },
|
||||
{ name = "towncrier", specifier = "~=25.8.0" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "sphinx", specifier = ">=8.1.3" },
|
||||
@@ -7243,6 +7253,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621, upload-time = "2025-04-23T14:41:46.06Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "towncrier"
|
||||
version = "25.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c2/eb/5bf25a34123698d3bbab39c5bc5375f8f8bcbcc5a136964ade66935b8b9d/towncrier-25.8.0.tar.gz", hash = "sha256:eef16d29f831ad57abb3ae32a0565739866219f1ebfbdd297d32894eb9940eb1", size = 76322, upload-time = "2025-08-30T11:41:55.393Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/42/06/8ba22ec32c74ac1be3baa26116e3c28bc0e76a5387476921d20b6fdade11/towncrier-25.8.0-py3-none-any.whl", hash = "sha256:b953d133d98f9aeae9084b56a3563fd2519dfc6ec33f61c9cd2c61ff243fb513", size = 65101, upload-time = "2025-08-30T11:41:53.644Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.67.1"
|
||||
|
||||
Reference in New Issue
Block a user