Compare commits
3 Commits
v0.0.97
...
hush/TurnT
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bbfa829d3 | ||
|
|
c2eb663bdc | ||
|
|
bf055843e6 |
174
.github/workflows/generate-changelog.yml
vendored
174
.github/workflows/generate-changelog.yml
vendored
@@ -1,174 +0,0 @@
|
||||
name: Generate Changelog for Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "Release version (e.g., 0.0.97)"
|
||||
required: true
|
||||
type: string
|
||||
date:
|
||||
description: "Release date (YYYY-MM-DD format, defaults to today)"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
generate-changelog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Set release date
|
||||
id: set_date
|
||||
run: |
|
||||
if [ -z "${{ inputs.date }}" ]; then
|
||||
RELEASE_DATE=$(date +%Y-%m-%d)
|
||||
echo "Using today's date: $RELEASE_DATE"
|
||||
else
|
||||
RELEASE_DATE="${{ inputs.date }}"
|
||||
echo "Using provided date: $RELEASE_DATE"
|
||||
fi
|
||||
echo "release_date=$RELEASE_DATE" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Validate inputs
|
||||
run: |
|
||||
# Validate version format (basic check)
|
||||
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+.*$ ]]; then
|
||||
echo "Error: Version must be in format X.Y.Z (e.g., 0.0.97)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate date format if provided
|
||||
if [ -n "${{ inputs.date }}" ]; then
|
||||
if ! date -d "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
# Try macOS date format
|
||||
if ! date -j -f "%Y-%m-%d" "${{ inputs.date }}" >/dev/null 2>&1; then
|
||||
echo "Error: Date must be in YYYY-MM-DD format (e.g., 2025-12-04)"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Check for changelog fragments
|
||||
id: check_fragments
|
||||
run: |
|
||||
FRAGMENT_COUNT=$(find changelog -name "*.md" ! -name "_template.md.j2" | wc -l | tr -d ' ')
|
||||
echo "fragment_count=$FRAGMENT_COUNT" >> $GITHUB_OUTPUT
|
||||
|
||||
if [ "$FRAGMENT_COUNT" -eq "0" ]; then
|
||||
echo "❌ Error: No changelog fragments found in changelog/"
|
||||
echo ""
|
||||
echo "Cannot create a release without changelog entries."
|
||||
echo "Add changelog fragments to the changelog/ directory (e.g., 1234.added.md) and try again."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate fragment types
|
||||
VALID_TYPES="added changed deprecated removed fixed security"
|
||||
INVALID_FRAGMENTS=""
|
||||
|
||||
for file in changelog/*.md; do
|
||||
# Skip template
|
||||
if [[ "$file" == "changelog/_template.md.j2" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract type from filename (e.g., 1234.added.md -> added)
|
||||
filename=$(basename "$file")
|
||||
# Handle both 1234.added.md and 1234.added.2.md patterns
|
||||
type=$(echo "$filename" | sed -E 's/^[0-9]+\.([a-z]+)(\.[0-9]+)?\.md$/\1/')
|
||||
|
||||
# Check if type is valid
|
||||
if ! echo "$VALID_TYPES" | grep -wq "$type"; then
|
||||
INVALID_FRAGMENTS="$INVALID_FRAGMENTS\n - $filename (type: '$type')"
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -n "$INVALID_FRAGMENTS" ]; then
|
||||
echo "❌ Error: Invalid changelog fragment types found:"
|
||||
echo -e "$INVALID_FRAGMENTS"
|
||||
echo ""
|
||||
echo "Valid types are: $VALID_TYPES"
|
||||
echo "Example: 1234.added.md, 5678.fixed.md"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Found $FRAGMENT_COUNT changelog fragment(s)"
|
||||
echo "has_fragments=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Preview changelog
|
||||
run: |
|
||||
echo "## Preview of changelog for version ${{ inputs.version }}"
|
||||
echo ""
|
||||
uv run towncrier build --draft --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}"
|
||||
|
||||
- name: Build changelog
|
||||
run: |
|
||||
uv run towncrier build --version "${{ inputs.version }}" --date "${{ steps.set_date.outputs.release_date }}" --yes
|
||||
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update changelog for version ${{ inputs.version }}"
|
||||
title: "Release ${{ inputs.version }} - Changelog Update"
|
||||
body: |
|
||||
## Changelog Update for Release ${{ inputs.version }}
|
||||
|
||||
This PR updates the CHANGELOG.md with all changes for version **${{ inputs.version }}**.
|
||||
|
||||
### Summary
|
||||
- **Version:** ${{ inputs.version }}
|
||||
- **Date:** ${{ steps.set_date.outputs.release_date }}
|
||||
- **Fragments processed:** ${{ steps.check_fragments.outputs.fragment_count }}
|
||||
|
||||
### What this PR does
|
||||
- ✅ Adds new release section to CHANGELOG.md
|
||||
- ✅ Removes processed changelog fragments
|
||||
- ✅ Ready to merge for release
|
||||
|
||||
### Next Steps
|
||||
1. Review the changelog entries below
|
||||
2. Make any necessary edits to CHANGELOG.md if needed
|
||||
3. Merge this PR
|
||||
4. Continue with your release process
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>📋 Preview of changes</summary>
|
||||
|
||||
The changelog has been updated with entries from the following fragments:
|
||||
|
||||
```bash
|
||||
${{ steps.check_fragments.outputs.fragment_count }} fragments processed
|
||||
```
|
||||
|
||||
</details>
|
||||
branch: changelog-${{ inputs.version }}
|
||||
delete-branch: true
|
||||
labels: |
|
||||
changelog
|
||||
release
|
||||
152
CHANGELOG.md
152
CHANGELOG.md
@@ -5,145 +5,10 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
<!-- towncrier release notes start -->
|
||||
|
||||
## [0.0.97] - 2025-12-05
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
|
||||
speech-to-text and text-to-speech functionality using Gradium's API.
|
||||
|
||||
- Additions for `AsyncAITTSService` and `AsyncAIHttpTTSService`:
|
||||
|
||||
- Added new `languages`: `pt`, `nl`, `ar`, `ru`, `ro`, `ja`, `he`, `hy`,
|
||||
`tr`, `hi`, `zh`.
|
||||
- Updated the default model to `asyncflow_multilingual_v1.0` for improved
|
||||
accuracy and broader language coverage.
|
||||
|
||||
- Added optional tool and tool output filters for MCP services.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Deepgram logging to include Deepgram request IDs for improved
|
||||
debugging.
|
||||
|
||||
- Text Aggregation Improvements:
|
||||
|
||||
- **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
|
||||
`AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
|
||||
enables the aggregator to return multiple results based on the provided
|
||||
text.
|
||||
- Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
|
||||
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
|
||||
the base class's sentence detection logic.
|
||||
|
||||
- Improved interruption handling to prevent bots from repeating themselves. LLM
|
||||
services that return multiple sentences in a single response (e.g.,
|
||||
`GoogleLLMService`) are now split into individual sentences before being sent
|
||||
to TTS. This ensures interruptions occur at sentence boundaries, preventing
|
||||
the bot from repeating content after being interrupted during long responses.
|
||||
|
||||
- Updated `AICFilter` to use Quail STT as the default model
|
||||
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
|
||||
interaction (e.g., voice agents, speech-to-text) and operates at a native
|
||||
sample rate of 16 kHz with fixed enhancement parameters.
|
||||
|
||||
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is
|
||||
called with an exception, the file name and line number where the exception
|
||||
occured are now logged.
|
||||
|
||||
- Updated Smart Turn model weights to v3.1.
|
||||
|
||||
- Smart Turn analyzer now uses the full context of the turn rather than just
|
||||
the audio since VAD last triggered.
|
||||
|
||||
- Updated `CartesiaSTTService` to return the full transcription `result` in the
|
||||
`TranscriptionFrame` and `InterimTranscriptionFrame`. This provides access to
|
||||
word timestamp data.
|
||||
|
||||
- `HumeTTSService` changes:
|
||||
|
||||
- Added tracking headers (`X-Hume-Client-Name` and `X-Hume-Client-Version`)
|
||||
to all requests made by `HumeTTSService` to the Hume API for better usage
|
||||
tracking and analytics.
|
||||
- Added `stop()` and `cancel()` cleanup methods to `HumeTTSService` to
|
||||
properly close the HTTP client and prevent resource leaks.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- NVIDIA Services name changes (all functionality is unchanged):
|
||||
|
||||
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
|
||||
- `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
|
||||
- `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
|
||||
- Use `uv pip install pipecat-ai[nvidia]` instead of
|
||||
`uv pip install pipecat-ai[riva]`
|
||||
|
||||
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
|
||||
has any effect. Noise gating is now handled automatically by the AIC VAD
|
||||
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
|
||||
|
||||
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
|
||||
multiple times for `KEEP` or `AGGREGATE` patterns.
|
||||
|
||||
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
|
||||
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
|
||||
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was always
|
||||
set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
|
||||
spoken. Now, audio is flushed when the TTS services receives the
|
||||
`LLMFullResponseEndFrame` or `EndFrame`.
|
||||
|
||||
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
|
||||
incorrectly pushed after a functional call. This caused an issue with the
|
||||
voice-ui-kit's conversational panel rending of the LLM output after a
|
||||
function call.
|
||||
|
||||
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
|
||||
services.
|
||||
|
||||
- Fixed an issue that caused `WebsocketService` instances to attempt
|
||||
reconnection during shutdown.
|
||||
|
||||
- Fixed an issue in `ElevenLabsTTSService` where character usage metrics were
|
||||
only reported on the first TTS generation per turn.
|
||||
|
||||
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃
|
||||
|
||||
### Added
|
||||
|
||||
- Added `AWSBedrockAgentCoreProcessor` to support invoking an AgentCore-hosted
|
||||
agent in a Pipecat pipeline.
|
||||
|
||||
- Enhanced error handling across the framework:
|
||||
|
||||
- Added `on_error` callback to `FrameProcessor` for centralized error
|
||||
handling.
|
||||
|
||||
- Renamed `push_error(error: ErrorFrame)` to `push_error_frame(error: ErrorFrame)`
|
||||
for clarity.
|
||||
|
||||
- Added new `push_error` method for simplified error reporting:
|
||||
|
||||
```python
|
||||
async def push_error(error_msg: str,
|
||||
exception: Optional[Exception] = None,
|
||||
fatal: bool = False)
|
||||
```
|
||||
|
||||
- Standardized error logging by replacing `logger.exception` calls with
|
||||
`logger.error` throughout the codebase.
|
||||
|
||||
- Added `cache_read_input_tokens`, `cache_creation_input_tokens` and
|
||||
`reasoning_tokens` to OTel spans for LLM call
|
||||
|
||||
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
|
||||
|
||||
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted
|
||||
@@ -223,18 +88,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added new emotions: calm and fluent
|
||||
|
||||
- Added `enable_logging` to `SimliVideoService` input parameters. It's disabled
|
||||
by default.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `FishAudioTTSService` default model to `s1`.
|
||||
|
||||
- Updated `DeepgramTTSService` to use Deepgram's TTS websocket API. ⚠️ This is
|
||||
a potential breaking change, which only affects you if you're self-hosting
|
||||
`DeepgramTTSService`. The new service uses Websockets and improves TTFB
|
||||
latency.
|
||||
|
||||
- Updated `daily-python` to 0.22.0.
|
||||
|
||||
- `BaseTextAggregator` changes:
|
||||
@@ -392,11 +247,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `AWSBedrockLLMService` where the `aws_region` arg was
|
||||
always set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue with `DeepgramFluxSTTService` where it sometimes failed to reconnect.
|
||||
|
||||
- Fixed an issue in `ElevenLabsRealtimeSTTService` where dynamic language
|
||||
updates were not working.
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
|
||||
|
||||
**Examples:**
|
||||
|
||||
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
|
||||
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
|
||||
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
|
||||
|
||||
#### Key requirements:
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -17,121 +17,24 @@ We welcome contributions of all kinds! Your help is appreciated. Follow these st
|
||||
git checkout -b your-branch-name
|
||||
```
|
||||
4. **Make your changes**: Edit or add files as necessary.
|
||||
5. **Add a changelog entry**: Create a changelog fragment file (see [Changelog Entries](#changelog-entries) below).
|
||||
6. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
7. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
5. **Test your changes**: Ensure that your changes look correct and follow the style set in the codebase.
|
||||
6. **Commit your changes**: Once you're satisfied with your changes, commit them with a meaningful message.
|
||||
|
||||
```bash
|
||||
git commit -m "Description of your changes"
|
||||
```
|
||||
|
||||
8. **Push your changes**: Push your branch to your forked repository.
|
||||
7. **Push your changes**: Push your branch to your forked repository.
|
||||
|
||||
```bash
|
||||
git push origin your-branch-name
|
||||
```
|
||||
|
||||
9. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
8. **Submit a Pull Request (PR)**: Open a PR from your forked repository to the main branch of this repo.
|
||||
> Important: Describe the changes you've made clearly!
|
||||
|
||||
Our maintainers will review your PR, and once everything is good, your contributions will be merged!
|
||||
|
||||
## Changelog Entries
|
||||
|
||||
Every pull request that makes a user-facing change should include a changelog entry. We use a changelog fragment system to avoid merge conflicts.
|
||||
|
||||
### Creating a Changelog Fragment
|
||||
|
||||
1. Create a new file in the `changelog/` directory with this naming pattern:
|
||||
|
||||
```
|
||||
<PR_number>.<type>.md
|
||||
```
|
||||
|
||||
2. Choose the appropriate type:
|
||||
|
||||
- `added.md` - New features
|
||||
- `changed.md` - Changes in existing functionality
|
||||
- `deprecated.md` - Soon-to-be removed features
|
||||
- `removed.md` - Removed features
|
||||
- `fixed.md` - Bug fixes
|
||||
- `security.md` - Security fixes
|
||||
|
||||
3. Write your changelog entry as a Markdown bullet point. Include the `-` at the start:
|
||||
|
||||
**Example files:**
|
||||
|
||||
`changelog/1234.added.md`:
|
||||
|
||||
```markdown
|
||||
- Added support for Anthropic Claude 3.5 Sonnet with improved streaming performance.
|
||||
```
|
||||
|
||||
`changelog/5678.fixed.md`:
|
||||
|
||||
```markdown
|
||||
- Fixed an issue where audio frames were dropped during high-load scenarios.
|
||||
```
|
||||
|
||||
**For entries with nested bullets:**
|
||||
|
||||
`changelog/1234.changed.md`:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
### Multiple Changes in One PR
|
||||
|
||||
**Different types of changes:** Create separate fragment files for each type:
|
||||
|
||||
```
|
||||
changelog/1234.added.md
|
||||
changelog/1234.fixed.md
|
||||
```
|
||||
|
||||
**Multiple changes of the same type:** Create numbered fragment files:
|
||||
|
||||
```
|
||||
changelog/1234.changed.md
|
||||
changelog/1234.changed.2.md
|
||||
```
|
||||
|
||||
**Related changes:** Use nested bullets in a single fragment:
|
||||
|
||||
```markdown
|
||||
- Updated service configuration:
|
||||
|
||||
- Changed default timeout to 30 seconds
|
||||
- Added retry logic for failed connections
|
||||
```
|
||||
|
||||
**Rule of thumb:** One logical change per fragment file. If changes are unrelated, use separate files.
|
||||
|
||||
### Preview Your Changes
|
||||
|
||||
To see what your changelog entry will look like:
|
||||
|
||||
```bash
|
||||
towncrier build --draft --version Unreleased
|
||||
```
|
||||
|
||||
This won't modify any files, just show you a preview.
|
||||
|
||||
### When to Skip Changelog Entries
|
||||
|
||||
You can skip adding a changelog entry for:
|
||||
|
||||
- Documentation-only changes
|
||||
- Internal refactoring with no user-facing impact
|
||||
- Test-only changes
|
||||
- CI/build configuration changes
|
||||
|
||||
If you're unsure whether your change needs a changelog entry, ask in your PR!
|
||||
|
||||
## Dependency Management
|
||||
|
||||
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. The `uv.lock` file is committed to ensure reproducible builds.
|
||||
|
||||
@@ -74,9 +74,9 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
{% for section, _ in sections.items() %}
|
||||
{% if sections[section] %}
|
||||
{% for category, val in definitions.items() if category in sections[section]%}
|
||||
### {{ definitions[category]['name'] }}
|
||||
|
||||
{% for text, values in sections[section][category].items() %}
|
||||
{{ text }}
|
||||
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
No significant changes.
|
||||
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
103
docs/TURN_AWARE_TRANSCRIPT_PROCESSOR.md
Normal file
103
docs/TURN_AWARE_TRANSCRIPT_PROCESSOR.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# TurnAwareTranscriptProcessor Example
|
||||
|
||||
## Overview
|
||||
|
||||
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
|
||||
|
||||
# Create the processor
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Register event handlers
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
|
||||
print(f"\nTurn {turn_number} ended:")
|
||||
print(f" User said: {user_text}")
|
||||
print(f" Assistant said: {assistant_text}")
|
||||
print(f" Was interrupted: {was_interrupted}")
|
||||
|
||||
@turn_processor.event_handler("on_transcript_update")
|
||||
async def handle_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
print(f"[{msg.role}]: {msg.content}")
|
||||
|
||||
# Add to pipeline
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor, # Process transcripts and track turns
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
|
||||
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
|
||||
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
|
||||
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
|
||||
|
||||
## Events
|
||||
|
||||
### on_turn_started
|
||||
Emitted when a new turn begins (user starts speaking).
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number)`
|
||||
|
||||
### on_turn_ended
|
||||
Emitted when a turn ends with accumulated transcripts.
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
|
||||
|
||||
### on_transcript_update
|
||||
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
|
||||
|
||||
**Handler signature**: `async def handler(processor, frame)`
|
||||
|
||||
## Turn Logic
|
||||
|
||||
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
|
||||
- Turns end when:
|
||||
- The user starts speaking again (previous turn ends, new turn starts)
|
||||
- The bot is interrupted (`InterruptionFrame`)
|
||||
- The pipeline ends (`EndFrame`/`CancelFrame`)
|
||||
|
||||
## Integration with OpenTelemetry
|
||||
|
||||
You can use turn events to enrich OpenTelemetry spans:
|
||||
|
||||
```python
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
turn_tracker = TurnTrackingObserver()
|
||||
turn_tracer = TurnTraceObserver(turn_tracker)
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
# Get current span and add transcript data
|
||||
from opentelemetry import trace
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("turn.user_text", user_text)
|
||||
current_span.set_attribute("turn.assistant_text", assistant_text)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The processor handles async frame processing correctly by delaying turn end until frames are processed
|
||||
- Works with word-level timestamps from TTS services like Cartesia
|
||||
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
|
||||
- Emits individual transcript messages in addition to turn-level aggregation
|
||||
@@ -119,6 +119,7 @@ def import_core_modules():
|
||||
"pipecat.observers",
|
||||
"pipecat.runner",
|
||||
"pipecat.serializers",
|
||||
"pipecat.sync",
|
||||
"pipecat.transcriptions",
|
||||
"pipecat.utils",
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ Quick Links
|
||||
Runner <api/pipecat.runner>
|
||||
Serializers <api/pipecat.serializers>
|
||||
Services <api/pipecat.services>
|
||||
Sync <api/pipecat.sync>
|
||||
Transcriptions <api/pipecat.transcriptions>
|
||||
Transports <api/pipecat.transports>
|
||||
Utils <api/pipecat.utils>
|
||||
Utils <api/pipecat.utils>
|
||||
@@ -73,9 +73,6 @@ GOOGLE_CLOUD_PROJECT_ID=...
|
||||
GOOGLE_CLOUD_LOCATION=...
|
||||
GOOGLE_TEST_CREDENTIALS=...
|
||||
|
||||
# Gradium
|
||||
GRAPDIUM_API_KEY=...
|
||||
|
||||
# Grok
|
||||
GROK_API_KEY=...
|
||||
|
||||
@@ -194,4 +191,4 @@ TWILIO_AUTH_TOKEN=...
|
||||
WHATSAPP_TOKEN=...
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
|
||||
WHATSAPP_PHONE_NUMBER_ID=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.services.riva.tts import FastPitchTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -36,7 +36,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
task = PipelineTask(
|
||||
Pipeline([tts, transport.output()]),
|
||||
@@ -1,127 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nvidia.stt import NvidiaSTTService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.riva.stt import RivaSTTService
|
||||
from pipecat.services.riva.tts import RivaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,13 +59,11 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
)
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
|
||||
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -76,7 +76,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = FireworksLLMService(
|
||||
api_key=os.getenv("FIREWORKS_API_KEY"),
|
||||
model="accounts/fireworks/models/gpt-oss-20b",
|
||||
model="accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NvidiaLLMService(
|
||||
llm = NimLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Recommended when turning thinking off
|
||||
params=NvidiaLLMService.InputParams(temperature=0.0),
|
||||
params=NimLLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
@@ -14,13 +14,20 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
|
||||
from pipecat.frames.frames import (
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
TranscriptionMessage,
|
||||
)
|
||||
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
|
||||
@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -64,14 +64,11 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
|
||||
@@ -91,23 +88,6 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# full list of tools available from rijksmuseum MCP:
|
||||
# - get_artwork_details
|
||||
# - get_artwork_image
|
||||
# - get_user_sets
|
||||
# - get_user_set_details
|
||||
# - open_image_in_browser
|
||||
# - get_artist_timeline
|
||||
|
||||
mcp_tools_filter = ["get_artwork_details", "get_artwork_image", "open_image_in_browser"]
|
||||
|
||||
|
||||
def open_image_output_filter(output: str):
|
||||
pattern = r"Successfully opened image in browser: "
|
||||
text_to_print = re.sub(pattern, "", output)
|
||||
print(f"🖼️ link to high resolution artwork: {text_to_print}")
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -156,10 +136,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
),
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
|
||||
@@ -67,14 +67,13 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except:
|
||||
pass
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
try:
|
||||
|
||||
@@ -45,7 +45,7 @@ Source = "https://github.com/pipecat-ai/pipecat"
|
||||
Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.2.0" ]
|
||||
aic = [ "aic-sdk~=1.1.0" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
@@ -55,7 +55,7 @@ azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
daily = [ "daily-python~=0.22.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0", "pipecat-ai[websockets-base]" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
deepseek = []
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
@@ -63,7 +63,6 @@ fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
|
||||
gradium = [ "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
@@ -84,8 +83,8 @@ mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
nim = []
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
nvidia = [ "nvidia-riva-client~=2.21.1" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
@@ -94,7 +93,7 @@ playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "pipecat-ai[nvidia]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
@@ -130,7 +129,6 @@ dev = [
|
||||
"setuptools~=78.1.1",
|
||||
"setuptools_scm~=8.3.1",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"towncrier~=25.8.0",
|
||||
]
|
||||
|
||||
docs = [
|
||||
@@ -161,7 +159,7 @@ where = ["src"]
|
||||
"src/pipecat/audio/dtmf/dtmf-star.wav",
|
||||
]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.1-cpu.onnx"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
@@ -208,44 +206,3 @@ convention = "google"
|
||||
command_line = "--module pytest"
|
||||
source = ["src"]
|
||||
omit = ["*/tests/*"]
|
||||
|
||||
[tool.towncrier]
|
||||
package = "pipecat"
|
||||
package_dir = "src"
|
||||
filename = "CHANGELOG.md"
|
||||
directory = "changelog"
|
||||
start_string = "<!-- towncrier release notes start -->\n"
|
||||
template = "changelog/_template.md.j2"
|
||||
title_format = "## [{version}] - {project_date}"
|
||||
underlines = ["", "", ""]
|
||||
wrap = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "added"
|
||||
name = "Added"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "changed"
|
||||
name = "Changed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "deprecated"
|
||||
name = "Deprecated"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "removed"
|
||||
name = "Removed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "fixed"
|
||||
name = "Fixed"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "security"
|
||||
name = "Security"
|
||||
showcontent = true
|
||||
|
||||
@@ -103,7 +103,7 @@ TESTS_07 = [
|
||||
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
|
||||
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
|
||||
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
|
||||
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
|
||||
@@ -136,7 +136,7 @@ TESTS_14 = [
|
||||
("14g-function-calling-grok.py", EVAL_WEATHER),
|
||||
("14h-function-calling-azure.py", EVAL_WEATHER),
|
||||
("14i-function-calling-fireworks.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nvidia.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nim.py", EVAL_WEATHER),
|
||||
("14k-function-calling-cerebras.py", EVAL_WEATHER),
|
||||
("14m-function-calling-openrouter.py", EVAL_WEATHER),
|
||||
("14n-function-calling-perplexity.py", EVAL_WEATHER),
|
||||
|
||||
@@ -39,7 +39,7 @@ class AICFilter(BaseAudioFilter):
|
||||
self,
|
||||
*,
|
||||
license_key: str = "",
|
||||
model_type: AICModelType = AICModelType.QUAIL_STT,
|
||||
model_type: AICModelType = AICModelType.QUAIL_L,
|
||||
enhancement_level: Optional[float] = 1.0,
|
||||
voice_gain: Optional[float] = 1.0,
|
||||
noise_gate_enable: Optional[bool] = True,
|
||||
@@ -52,27 +52,12 @@ class AICFilter(BaseAudioFilter):
|
||||
enhancement_level: Optional overall enhancement strength (0.0..1.0).
|
||||
voice_gain: Optional linear gain applied to detected speech (0.0..4.0).
|
||||
noise_gate_enable: Optional enable/disable noise gate (default: True).
|
||||
|
||||
.. deprecated:: 1.3.0
|
||||
The `noise_gate_enable` parameter is deprecated and no longer has any effect.
|
||||
It will be removed in a future version.
|
||||
"""
|
||||
self._license_key = license_key
|
||||
self._model_type = model_type
|
||||
|
||||
self._enhancement_level = enhancement_level
|
||||
self._voice_gain = voice_gain
|
||||
if noise_gate_enable is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `noise_gate_enable` is deprecated and no longer has any effect. "
|
||||
"It will be removed in a future version. Use AIC VAD instead (create_vad_analyzer()).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._noise_gate_enable = noise_gate_enable
|
||||
|
||||
self._enabled = True
|
||||
@@ -164,6 +149,10 @@ class AICFilter(BaseAudioFilter):
|
||||
)
|
||||
if self._voice_gain is not None:
|
||||
self._aic.set_parameter(AICParameter.VOICE_GAIN, float(self._voice_gain))
|
||||
if self._noise_gate_enable is not None:
|
||||
self._aic.set_parameter(
|
||||
AICParameter.NOISE_GATE_ENABLE, 1.0 if bool(self._noise_gate_enable) else 0.0
|
||||
)
|
||||
|
||||
self._aic_ready = True
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from pipecat.metrics.metrics import MetricsData, SmartTurnMetricsData
|
||||
STOP_SECS = 3
|
||||
PRE_SPEECH_MS = 0
|
||||
MAX_DURATION_SECONDS = 8 # Max allowed segment duration
|
||||
USE_ONLY_LAST_VAD_SEGMENT = True
|
||||
|
||||
|
||||
class SmartTurnParams(BaseTurnParams):
|
||||
@@ -42,6 +43,8 @@ class SmartTurnParams(BaseTurnParams):
|
||||
stop_secs: float = STOP_SECS
|
||||
pre_speech_ms: float = PRE_SPEECH_MS
|
||||
max_duration_secs: float = MAX_DURATION_SECONDS
|
||||
# not exposing this for now yet until the model can handle it.
|
||||
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
|
||||
|
||||
|
||||
class SmartTurnTimeoutException(Exception):
|
||||
@@ -157,7 +160,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
state, result = await loop.run_in_executor(
|
||||
self._executor, self._process_speech_segment, self._audio_buffer
|
||||
)
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
|
||||
self._clear(state)
|
||||
logger.debug(f"End of Turn result: {state}")
|
||||
return state, result
|
||||
|
||||
Binary file not shown.
@@ -42,15 +42,17 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
Args:
|
||||
smart_turn_model_path: Path to the ONNX model file. If this is not
|
||||
set, the bundled smart-turn-v3.1-cpu model will be used.
|
||||
set, the bundled smart-turn-v3.0 model will be used.
|
||||
cpu_count: The number of CPUs to use for inference. Defaults to 1.
|
||||
**kwargs: Additional arguments passed to BaseSmartTurn.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
logger.debug("Loading Local Smart Turn v3 model...")
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.1-cpu.onnx"
|
||||
model_name = "smart-turn-v3.0.onnx"
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
@@ -68,8 +70,6 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
)
|
||||
|
||||
logger.debug(f"Loading Local Smart Turn v3.x model from {smart_turn_model_path}...")
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
so.inter_op_num_threads = 1
|
||||
@@ -79,7 +79,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
logger.debug("Loaded Local Smart Turn v3")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
@@ -18,10 +18,8 @@ from loguru import logger
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMTextFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -151,18 +149,11 @@ class IVRProcessor(FrameProcessor):
|
||||
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
# Process text through the pattern aggregator
|
||||
async for result in self._aggregator.aggregate(frame.text):
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
# Flush any remaining text from the aggregator
|
||||
remaining = await self._aggregator.flush()
|
||||
if remaining:
|
||||
await self.push_frame(LLMTextFrame(remaining.text), direction)
|
||||
# Push the end frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class NotifierGate(FrameProcessor):
|
||||
|
||||
@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
|
||||
"""
|
||||
|
||||
text: str
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
# Whether any necessary inter-frame (leading/trailing) spaces are already
|
||||
# included in the text.
|
||||
# NOTE: Ideally this would be available at init time with a default value,
|
||||
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
@@ -835,13 +835,11 @@ class ErrorFrame(SystemFrame):
|
||||
error: Description of the error that occurred.
|
||||
fatal: Whether the error is fatal and requires bot shutdown.
|
||||
processor: The frame processor that generated the error.
|
||||
exception: The exception that occurred.
|
||||
"""
|
||||
|
||||
error: str
|
||||
fatal: bool = False
|
||||
processor: Optional["FrameProcessor"] = None
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(error: {self.error}, fatal: {self.fatal})"
|
||||
@@ -1632,22 +1630,22 @@ class LLMFullResponseStartFrame(ControlFrame):
|
||||
more TextFrames and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseEndFrame(ControlFrame):
|
||||
"""Frame indicating the end of an LLM response."""
|
||||
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class GatedLLMContextAggregator(FrameProcessor):
|
||||
|
||||
@@ -83,7 +83,8 @@ class LLMTextProcessor(FrameProcessor):
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
@@ -91,13 +92,15 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
|
||||
# Flush any remaining text
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -126,4 +126,6 @@ class WakeCheckFilter(FrameProcessor):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error in wake word filter: {e}", exception=e)
|
||||
error_msg = f"Error in wake word filter: {e}"
|
||||
logger.exception(error_msg)
|
||||
await self.push_error(ErrorFrame(error_msg))
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class WakeNotifierFilter(FrameProcessor):
|
||||
|
||||
@@ -12,7 +12,6 @@ management, and frame flow control mechanisms.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence, Tuple, Type
|
||||
@@ -143,7 +142,6 @@ class FrameProcessor(BaseObject):
|
||||
- on_after_process_frame: Called after a frame is processed
|
||||
- on_before_push_frame: Called before a frame is pushed
|
||||
- on_after_push_frame: Called after a frame is pushed
|
||||
- on_error: Called when an error is raised in the frame processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -236,7 +234,6 @@ class FrameProcessor(BaseObject):
|
||||
self._register_event_handler("on_after_process_frame", sync=True)
|
||||
self._register_event_handler("on_before_push_frame", sync=True)
|
||||
self._register_event_handler("on_after_push_frame", sync=True)
|
||||
self._register_event_handler("on_error", sync=True)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -633,43 +630,7 @@ class FrameProcessor(BaseObject):
|
||||
elif isinstance(frame, (FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame)):
|
||||
await self.__resume(frame)
|
||||
|
||||
async def push_error(
|
||||
self,
|
||||
error_msg: str,
|
||||
exception: Optional[Exception] = None,
|
||||
fatal: bool = False,
|
||||
):
|
||||
"""Creates and pushes an ErrorFrame upstream.
|
||||
|
||||
Creates and pushes an ErrorFrame upstream to notify other processors in the
|
||||
pipeline about an error condition. The error frame will include context about
|
||||
which processor generated the error.
|
||||
|
||||
Args:
|
||||
error_msg: Descriptive message explaining the error condition.
|
||||
exception: Optional exception object that caused the error, if available.
|
||||
This provides additional context for debugging and error handling.
|
||||
fatal: Whether this error should be considered fatal to the pipeline.
|
||||
Fatal errors typically cause the entire pipeline to stop processing.
|
||||
Defaults to False for non-fatal errors.
|
||||
|
||||
Example::
|
||||
|
||||
```python
|
||||
# Non-fatal error
|
||||
await self.push_error("Failed to process audio chunk, skipping")
|
||||
|
||||
# Fatal error with exception context
|
||||
try:
|
||||
result = some_critical_operation()
|
||||
except Exception as e:
|
||||
await self.push_error("Critical operation failed", exception=e, fatal=True)
|
||||
```
|
||||
"""
|
||||
error_frame = ErrorFrame(error=error_msg, fatal=fatal, exception=exception, processor=self)
|
||||
await self.push_error_frame(error=error_frame)
|
||||
|
||||
async def push_error_frame(self, error: ErrorFrame):
|
||||
async def push_error(self, error: ErrorFrame):
|
||||
"""Push an error frame upstream.
|
||||
|
||||
Args:
|
||||
@@ -677,18 +638,6 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
if not error.processor:
|
||||
error.processor = self
|
||||
await self._call_event_handler("on_error", error)
|
||||
|
||||
if error.exception:
|
||||
tb = traceback.extract_tb(error.exception.__traceback__)
|
||||
last = tb[-1]
|
||||
error_message = (
|
||||
f"{error.processor} exception ({last.filename}:{last.lineno}): {error.error}"
|
||||
)
|
||||
else:
|
||||
error_message = f"{error.processor} error: {error.error}"
|
||||
|
||||
logger.error(error_message)
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -810,10 +759,8 @@ class FrameProcessor(BaseObject):
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Uncaught exception handling _start_interruption: {e}",
|
||||
exception=e,
|
||||
)
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
@@ -850,7 +797,8 @@ class FrameProcessor(BaseObject):
|
||||
await self._observer.on_push_frame(data)
|
||||
await self._prev.queue_frame(frame, direction)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Uncaught exception: {e}", exception=e)
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
def _check_started(self, frame: Frame):
|
||||
"""Check if the processor has been started.
|
||||
@@ -926,7 +874,8 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_process_frame", frame)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error processing frame: {e}", exception=e)
|
||||
logger.exception(f"{self}: error processing frame: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
async def __input_frame_task_handler(self):
|
||||
"""Handle frames from the input queue.
|
||||
|
||||
@@ -24,7 +24,7 @@ try:
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables import Runnable
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
logger.exception("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -113,6 +113,6 @@ class LangchainProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -935,8 +935,8 @@ class RTVIObserverParams:
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
|
||||
@@ -23,7 +23,7 @@ try:
|
||||
from strands import Agent
|
||||
from strands.multiagent.graph import Graph
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
logger.exception("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
finally:
|
||||
if ttfb_tracking:
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
@@ -306,3 +308,267 @@ class TranscriptProcessor:
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes transcripts with turn boundary awareness.
|
||||
|
||||
This processor combines user and assistant transcript tracking with turn
|
||||
detection, emitting events when turns start and end. It correctly handles
|
||||
interruptions by only capturing what was actually spoken.
|
||||
|
||||
Turn boundaries are detected based on:
|
||||
- User started speaking (UserStartedSpeakingFrame)
|
||||
- Bot stopped speaking (BotStoppedSpeakingFrame)
|
||||
- Interruptions (InterruptionFrame)
|
||||
|
||||
Events:
|
||||
on_turn_started: Emitted when a new turn begins.
|
||||
Handler signature: async def handler(processor, turn_number)
|
||||
|
||||
on_turn_ended: Emitted when a turn ends.
|
||||
Handler signature: async def handler(processor, turn_number,
|
||||
user_transcript, assistant_transcript,
|
||||
was_interrupted)
|
||||
|
||||
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
|
||||
individual transcript messages.
|
||||
|
||||
Example::
|
||||
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
print(f"Turn {turn_number} ended")
|
||||
print(f"User said: {user_text}")
|
||||
print(f"Assistant said: {assistant_text}")
|
||||
print(f"Was interrupted: {interrupted}")
|
||||
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the turn-aware transcript processor.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Turn tracking state
|
||||
self._turn_number = 0
|
||||
self._turn_active = False
|
||||
self._turn_start_time: Optional[str] = None
|
||||
|
||||
# Accumulate text for current turn
|
||||
self._current_turn_user_parts: List[TextPartForConcatenation] = []
|
||||
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
|
||||
|
||||
# Track bot speaking state
|
||||
self._bot_is_speaking = False
|
||||
|
||||
# Register turn events
|
||||
self._register_event_handler("on_turn_started")
|
||||
self._register_event_handler("on_turn_ended")
|
||||
|
||||
async def _start_turn(self):
|
||||
"""Start a new turn."""
|
||||
if not self._turn_active:
|
||||
self._turn_number += 1
|
||||
self._turn_active = True
|
||||
self._turn_start_time = time_now_iso8601()
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
logger.debug(f"Turn {self._turn_number} started")
|
||||
await self._call_event_handler("on_turn_started", self._turn_number)
|
||||
|
||||
async def _end_turn(self, was_interrupted: bool = False):
|
||||
"""End the current turn and emit aggregated transcripts.
|
||||
|
||||
Args:
|
||||
was_interrupted: Whether the turn ended due to an interruption.
|
||||
"""
|
||||
if not self._turn_active:
|
||||
return
|
||||
|
||||
# Aggregate user text
|
||||
user_transcript = ""
|
||||
if self._current_turn_user_parts:
|
||||
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
|
||||
|
||||
# Aggregate assistant text
|
||||
assistant_transcript = ""
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
|
||||
# Emit turn ended event
|
||||
logger.debug(
|
||||
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
|
||||
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
|
||||
)
|
||||
await self._call_event_handler(
|
||||
"on_turn_ended",
|
||||
self._turn_number,
|
||||
user_transcript,
|
||||
assistant_transcript,
|
||||
was_interrupted,
|
||||
)
|
||||
|
||||
# Reset turn state
|
||||
self._turn_active = False
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for turn-aware transcript tracking.
|
||||
|
||||
Handles:
|
||||
- UserStartedSpeakingFrame: Start new turn
|
||||
- TranscriptionFrame: Accumulate user speech and emit transcript message
|
||||
- BotStartedSpeakingFrame: Track bot speaking state
|
||||
- TTSTextFrame: Accumulate assistant speech
|
||||
- BotStoppedSpeakingFrame: End turn if no interruption pending
|
||||
- InterruptionFrame: End turn immediately as interrupted
|
||||
- EndFrame/CancelFrame: End any active turn
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
direction: Frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
# User started speaking
|
||||
if self._bot_is_speaking:
|
||||
# This is an interruption - end the current turn with what was spoken
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
elif self._turn_active:
|
||||
# Previous turn is ending normally (bot finished speaking)
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=False)
|
||||
|
||||
# Start a new turn
|
||||
await self._start_turn()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
# Accumulate user speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_user_parts.append(
|
||||
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
|
||||
)
|
||||
|
||||
# Also emit individual transcript message
|
||||
message = TranscriptionMessage(
|
||||
role="user",
|
||||
user_id=frame.user_id,
|
||||
content=frame.text,
|
||||
timestamp=frame.timestamp,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
# Bot started speaking
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
# Accumulate assistant speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_assistant_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
# Bot stopped speaking - just mark it, don't end turn yet
|
||||
# Turn will end when next user speaks or pipeline ends
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
# Emit assistant transcript message with what was spoken before interruption
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
# Push frame first to ensure proper cleanup
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# End turn as interrupted
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
# Pipeline ending - finalize any active turn
|
||||
if self._turn_active:
|
||||
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
|
||||
# Give a brief moment for any pending frames to process
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -302,7 +302,7 @@ def _setup_webrtc_routes(
|
||||
result: StartBotResult = {"sessionId": session_id}
|
||||
if request_data.get("enableDefaultIceServers"):
|
||||
result["iceConfig"] = IceConfig(
|
||||
iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
|
||||
iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -199,7 +199,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Plivo call: {e}")
|
||||
logger.exception(f"Failed to hang up Plivo call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Plivo WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -225,7 +225,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Telnyx call: {e}")
|
||||
logger.exception(f"Failed to hang up Telnyx call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Telnyx WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -236,7 +236,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Twilio call: {e}")
|
||||
logger.exception(f"Failed to hang up Twilio call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Twilio WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -166,6 +166,6 @@ class AIService(FrameProcessor):
|
||||
async for f in generator:
|
||||
if f:
|
||||
if isinstance(f, ErrorFrame):
|
||||
await self.push_error_frame(f)
|
||||
await self.push_error(f)
|
||||
else:
|
||||
await self.push_frame(f)
|
||||
|
||||
@@ -458,7 +458,8 @@ class AnthropicLLMService(LLMService):
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"{e}"))
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -206,8 +206,9 @@ class AssemblyAISTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._connected = False
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -232,7 +233,8 @@ class AssemblyAISTTService(STTService):
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
@@ -240,7 +242,8 @@ class AssemblyAISTTService(STTService):
|
||||
await self._websocket.close()
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
finally:
|
||||
self._websocket = None
|
||||
@@ -259,11 +262,13 @@ class AssemblyAISTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
@@ -292,7 +297,8 @@ class AssemblyAISTTService(STTService):
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
|
||||
@@ -56,17 +56,6 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
Language.ES: "es",
|
||||
Language.DE: "de",
|
||||
Language.IT: "it",
|
||||
Language.PT: "pt",
|
||||
Language.NL: "nl",
|
||||
Language.AR: "ar",
|
||||
Language.RU: "ru",
|
||||
Language.RO: "ro",
|
||||
Language.JA: "ja",
|
||||
Language.HE: "he",
|
||||
Language.HY: "hy",
|
||||
Language.TR: "tr",
|
||||
Language.HI: "hi",
|
||||
Language.ZH: "zh",
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
@@ -85,7 +74,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -94,7 +83,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
voice_id: str,
|
||||
version: str = "v1",
|
||||
url: str = "wss://api.async.ai/text_to_speech/websocket/ws",
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
model: str = "asyncflow_v2.0",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
@@ -110,7 +99,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
https://docs.async.ai/list-voices-16699698e0
|
||||
version: Async API version.
|
||||
url: WebSocket URL for Async TTS API.
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
sample_rate: Audio sample rate.
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
@@ -139,7 +128,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else None,
|
||||
else "en",
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
@@ -239,7 +228,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -251,7 +241,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Async")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
@@ -296,11 +287,12 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("error_code"):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
|
||||
else:
|
||||
await self.push_error(error_msg=f"Unknown message type: {msg}")
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic keepalive messages to maintain WebSocket connection."""
|
||||
@@ -343,14 +335,16 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
@@ -368,7 +362,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
language: Language to use for synthesis.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -376,7 +370,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "asyncflow_multilingual_v1.0",
|
||||
model: str = "asyncflow_v2.0",
|
||||
url: str = "https://api.async.ai",
|
||||
version: str = "v1",
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -391,7 +385,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
api_key: Async API key.
|
||||
voice_id: ID of the voice to use for synthesis.
|
||||
aiohttp_session: An aiohttp session for making HTTP requests.
|
||||
model: TTS model to use (e.g., "asyncflow_multilingual_v1.0").
|
||||
model: TTS model to use (e.g., "asyncflow_v2.0").
|
||||
url: Base URL for Async API.
|
||||
version: API version string for Async API.
|
||||
sample_rate: Audio sample rate.
|
||||
@@ -415,7 +409,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else None,
|
||||
else "en",
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
@@ -483,7 +477,8 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
await self.push_error(error_msg=f"Async API error: {error_text}")
|
||||
logger.error(f"Async API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
|
||||
raise Exception(f"Async API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -499,7 +494,8 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -734,7 +734,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
aws_access_key: Optional[str] = None,
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: Optional[str] = None,
|
||||
aws_region: str = "us-east-1",
|
||||
params: Optional[InputParams] = None,
|
||||
client_config: Optional[Config] = None,
|
||||
retry_timeout_secs: Optional[float] = 5.0,
|
||||
@@ -1136,7 +1136,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
except (ReadTimeoutError, asyncio.TimeoutError):
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -453,7 +453,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._ready_to_send_context = True
|
||||
await self._finish_connecting_if_context_available()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
@@ -577,7 +577,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.info("Finished disconnecting")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
def _create_client(self) -> BedrockRuntimeClient:
|
||||
config = Config(
|
||||
@@ -885,7 +885,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# Errors are kind of expected while disconnecting, so just
|
||||
# ignore them and do nothing
|
||||
return
|
||||
await self.push_error(error_msg=f"Error processing responses: {e}", exception=e)
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
api_key: Optional[str] = None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
region: Optional[str] = "us-east-1",
|
||||
sample_rate: int = 16000,
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
@@ -69,7 +69,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
|
||||
aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
|
||||
aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
|
||||
region: AWS region for the service.
|
||||
region: AWS region for the service. Defaults to "us-east-1".
|
||||
sample_rate: Audio sample rate in Hz. Must be 8000 or 16000. Defaults to 16000.
|
||||
language: Language for transcription. Defaults to English.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
@@ -140,7 +140,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
@@ -181,7 +182,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
@@ -198,11 +200,13 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
@@ -285,7 +289,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
@@ -305,7 +310,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -523,15 +529,15 @@ class AWSTranscribeSTTService(STTService):
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(ErrorFrame(f"AWS Transcribe error: {error_msg}"))
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
await self.push_error(
|
||||
error_msg=f"WebSocket connection closed in receive loop", exception=e
|
||||
)
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
@@ -312,6 +312,7 @@ class AWSPollyTTSService(TTSService):
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
error_message = f"AWS Polly TTS error: {str(error)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
logger.error(f"{self} error: image generation timed out")
|
||||
yield ErrorFrame("Image generation timed out")
|
||||
return
|
||||
|
||||
@@ -103,6 +104,7 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -61,5 +61,5 @@ class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"initialization error: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
@@ -121,7 +121,8 @@ class AzureSTTService(STTService):
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech recognition service.
|
||||
@@ -150,9 +151,8 @@ class AzureSTTService(STTService):
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Uncaught exception during initialization: {e}", exception=e
|
||||
)
|
||||
logger.error(f"{self} exception during initialization: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech recognition service.
|
||||
|
||||
@@ -327,6 +327,7 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
@@ -354,13 +355,15 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
@@ -437,6 +440,5 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
yield ErrorFrame(
|
||||
error=f"Unknown error occurred: {cancellation_details.error_details}"
|
||||
)
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
@@ -20,6 +20,7 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -275,7 +276,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
@@ -283,7 +285,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Cartesia STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -316,7 +319,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
|
||||
elif data["type"] == "error":
|
||||
error_msg = data.get("message", "Unknown error")
|
||||
await self.push_error(error_msg=error_msg)
|
||||
logger.error(f"Cartesia error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(error=error_msg))
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
@@ -348,7 +352,6 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
@@ -361,6 +364,5 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -497,7 +497,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -509,7 +510,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -562,12 +564,13 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self.append_to_audio_context(msg["context_id"], frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
self._context_id = None
|
||||
else:
|
||||
await self.push_error(error_msg=f"Error, unknown message type: {msg}")
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
@@ -605,14 +608,16 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class CartesiaHttpTTSService(TTSService):
|
||||
@@ -803,7 +808,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
yield ErrorFrame(error=f"Cartesia API error: {error_text}")
|
||||
logger.error(f"Cartesia API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(error=f"Cartesia API error: {error_text}"))
|
||||
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -819,7 +825,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -150,17 +150,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
params=params
|
||||
)
|
||||
"""
|
||||
# Note: For DeepgramFluxSTTService, differently from other processes, we need to create
|
||||
# the _receive_task inside _connect_websocket, because the websocket should only be
|
||||
# considered connected and ready to send audio once we receive from Flux the message
|
||||
# which confirms the connection has been established.
|
||||
# If we try to keep the logic reconnect_on_error, when receiving a message, the
|
||||
# _receive_task_handler would try to reconnect in case of error, invoking the
|
||||
# _connect_websocket again and leading to a case where the first _receive_task_handler
|
||||
# was never destroyed.
|
||||
# So we can keep it here as false, because inside the method send_with_retry, it will
|
||||
# already try to reconnect if needed.
|
||||
super().__init__(sample_rate=sample_rate, reconnect_on_error=False, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
@@ -193,6 +183,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
# Creating the receiver task (only created once during initial connection)
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
# Creating the watchdog task (only created once during initial connection)
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from WebSocket and clean up tasks.
|
||||
|
||||
@@ -202,7 +200,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
try:
|
||||
await self._disconnect_websocket()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._websocket = None
|
||||
@@ -244,28 +243,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
self._receive_task_handler(self._report_error)
|
||||
)
|
||||
|
||||
# Creating the watchdog task
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
# Now wait for the connection established event
|
||||
logger.debug("WebSocket connected, waiting for server confirmation...")
|
||||
await self._connection_established_event.wait()
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -293,7 +278,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Deepgram Flux Websocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -303,13 +289,10 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
|
||||
This signals to the server that no more audio data will be sent.
|
||||
"""
|
||||
try:
|
||||
if self._websocket:
|
||||
logger.debug("Sending CloseStream message to Deepgram Flux")
|
||||
message = {"type": "CloseStream"}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error sending closeStream: {e}", exception=e)
|
||||
if self._websocket:
|
||||
logger.debug("Sending CloseStream message to Deepgram Flux")
|
||||
message = {"type": "CloseStream"}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -396,13 +379,16 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
are issues sending the audio data.
|
||||
"""
|
||||
if not self._websocket:
|
||||
logger.error("Not connected to Deepgram Flux.")
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._last_stt_time = time.monotonic()
|
||||
await self.send_with_retry(audio, self._report_error)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
yield None
|
||||
@@ -479,7 +465,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# Skip malformed messages
|
||||
continue
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Error will be handled inside WebsocketService->_receive_task_handler
|
||||
raise
|
||||
else:
|
||||
|
||||
@@ -233,14 +233,7 @@ class DeepgramSTTService(STTService):
|
||||
)
|
||||
|
||||
if not await self._connection.start(options=self._settings, addons=self._addons):
|
||||
await self.push_error(error_msg=f"Unable to connect to Deepgram")
|
||||
else:
|
||||
headers = {
|
||||
k: v
|
||||
for k, v in self._connection._socket.response.headers.items()
|
||||
if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
logger.error(f"{self}: unable to connect to Deepgram")
|
||||
|
||||
async def _disconnect(self):
|
||||
if await self._connection.is_connected():
|
||||
@@ -263,7 +256,7 @@ class DeepgramSTTService(STTService):
|
||||
async def _on_error(self, *args, **kwargs):
|
||||
error: ErrorResponse = kwargs["error"]
|
||||
logger.warning(f"{self} connection error, will retry: {error}")
|
||||
await self.push_error(error_msg=f"{error}")
|
||||
await self.push_error(ErrorFrame(error=f"{error}"))
|
||||
await self.stop_all_metrics()
|
||||
# NOTE(aleix): we don't disconnect (i.e. call finish on the connection)
|
||||
# because this triggers more errors internally in the Deepgram SDK. So,
|
||||
|
||||
@@ -210,7 +210,8 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"Error sending audio to SageMaker: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker STT error: {e}"))
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
@@ -259,7 +260,8 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"Failed to connect to SageMaker: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker connection error: {e}"))
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -340,7 +342,8 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"Error processing responses: {e}", exc_info=True)
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker response error: {e}"))
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
|
||||
@@ -10,45 +10,35 @@ This module provides integration with Deepgram's text-to-speech API
|
||||
for generating speech from text using various voice models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import TTSService, WebsocketTTSService
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
from deepgram import DeepgramClient, DeepgramClientOptions, SpeakOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramWebsocketTTSService, you need to `pip install pipecat-ai[deepgram]`."
|
||||
)
|
||||
logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramTTSService(WebsocketTTSService):
|
||||
"""Deepgram WebSocket-based text-to-speech service.
|
||||
class DeepgramTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service.
|
||||
|
||||
Provides real-time text-to-speech synthesis using Deepgram's WebSocket API.
|
||||
Supports streaming audio generation with interruption handling via the Clear
|
||||
message for conversational AI use cases.
|
||||
Provides text-to-speech synthesis using Deepgram's streaming API.
|
||||
Supports various voice models and audio encoding formats with
|
||||
configurable sample rates and quality settings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,220 +46,42 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
base_url: str = "wss://api.deepgram.com",
|
||||
base_url: str = "",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram WebSocket TTS service.
|
||||
"""Initialize the Deepgram TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Deepgram API key for authentication.
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
base_url: WebSocket base URL for Deepgram API. Defaults to "wss://api.deepgram.com".
|
||||
base_url: Custom base URL for Deepgram API. Uses default if empty.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
|
||||
**kwargs: Additional arguments passed to parent TTSService class.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._settings = {
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
|
||||
self._receive_task = None
|
||||
client_options = DeepgramClientOptions(url=base_url)
|
||||
self._deepgram_client = DeepgramClient(api_key, config=client_options)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram WebSocket TTS service supports metrics generation.
|
||||
True, as Deepgram TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# When the LLM finishes responding, flush any remaining text in Deepgram's buffer
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Deepgram WebSocket and start receive task."""
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Deepgram WebSocket and clean up tasks."""
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Deepgram WebSocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to Deepgram WebSocket")
|
||||
|
||||
# Build WebSocket URL with query parameters
|
||||
params = []
|
||||
params.append(f"model={self._voice_id}")
|
||||
params.append(f"encoding={self._settings['encoding']}")
|
||||
params.append(f"sample_rate={self.sample_rate}")
|
||||
|
||||
url = f"{self._base_url}/v1/speak?{'&'.join(params)}"
|
||||
|
||||
headers = {"Authorization": f"Token {self._api_key}"}
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close WebSocket connection and reset state."""
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Deepgram WebSocket")
|
||||
# Send Close message to gracefully close the connection
|
||||
await self._websocket.send(json.dumps({"type": "Close"}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Send Clear message to stop current audio generation
|
||||
if self._websocket:
|
||||
try:
|
||||
clear_msg = {"type": "Clear"}
|
||||
await self._websocket.send(json.dumps(clear_msg))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process messages from Deepgram WebSocket."""
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, bytes):
|
||||
# Binary message contains audio data
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(message, self.sample_rate, 1)
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(message, str):
|
||||
# Text message contains metadata or control messages
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
msg_type = msg.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {msg}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {msg}")
|
||||
# Flushed indicates the end of audio generation for the current buffer
|
||||
# This happens after flush_audio() is called
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {msg}")
|
||||
# Buffer has been cleared after interruption
|
||||
# TTSStoppedFrame will be sent by the interruption handler
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: {msg.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {msg}")
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._websocket:
|
||||
try:
|
||||
flush_msg = {"type": "Flush"}
|
||||
await self._websocket.send(json.dumps(flush_msg))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's WebSocket TTS API.
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
@@ -279,27 +91,33 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self.sample_rate,
|
||||
container="none",
|
||||
)
|
||||
|
||||
try:
|
||||
# Reconnect if the websocket is closed
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
response = await self._deepgram_client.speak.asyncrest.v("1").stream_raw(
|
||||
{"text": text}, options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Send text message to Deepgram
|
||||
# Note: We don't send Flush here - that should only be sent when the
|
||||
# LLM finishes a complete response via flush_audio()
|
||||
speak_msg = {"type": "Speak", "text": text}
|
||||
await self._get_websocket().send(json.dumps(speak_msg))
|
||||
async for data in response.aiter_bytes():
|
||||
await self.stop_ttfb_metrics()
|
||||
if data:
|
||||
yield TTSAudioRawFrame(audio=data, sample_rate=self.sample_rate, num_channels=1)
|
||||
|
||||
# The audio frames will be handled in _receive_messages
|
||||
yield None
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class DeepgramHttpTTSService(TTSService):
|
||||
@@ -409,4 +227,5 @@ class DeepgramHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
@@ -351,7 +351,8 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
def audio_format_from_sample_rate(sample_rate: int) -> str:
|
||||
@@ -597,6 +598,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}")
|
||||
|
||||
yield None
|
||||
@@ -661,9 +663,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to ElevenLabs Realtime STT")
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to ElevenLabs Realtime STT: {e}", exception=e
|
||||
)
|
||||
logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}")
|
||||
await self.push_error(ErrorFrame(f"Connection error: {str(e)}"))
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Disconnect from ElevenLabs Realtime STT WebSocket."""
|
||||
@@ -672,7 +673,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from ElevenLabs Realtime STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -732,17 +733,17 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
elif message_type == "error":
|
||||
error_msg = data.get("error", "Unknown error")
|
||||
logger.error(f"ElevenLabs error: {error_msg}")
|
||||
await self.push_error(error_msg=f"Error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Error: {error_msg}"))
|
||||
|
||||
elif message_type == "auth_error":
|
||||
error_msg = data.get("error", "Authentication error")
|
||||
logger.error(f"ElevenLabs auth error: {error_msg}")
|
||||
await self.push_error(error_msg=f"Auth error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Auth error: {error_msg}"))
|
||||
|
||||
elif message_type == "quota_exceeded_error":
|
||||
error_msg = data.get("error", "Quota exceeded")
|
||||
logger.error(f"ElevenLabs quota exceeded: {error_msg}")
|
||||
await self.push_error(error_msg=f"Quota exceeded: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Quota exceeded: {error_msg}"))
|
||||
|
||||
else:
|
||||
logger.debug(f"Unknown message type: {message_type}")
|
||||
|
||||
@@ -160,7 +160,7 @@ def build_elevenlabs_voice_settings(
|
||||
class PronunciationDictionaryLocator(BaseModel):
|
||||
"""Locator for a pronunciation dictionary.
|
||||
|
||||
Parameters:
|
||||
Attributes:
|
||||
pronunciation_dictionary_id: The ID of the pronunciation dictionary.
|
||||
version_id: The version ID of the pronunciation dictionary.
|
||||
"""
|
||||
@@ -424,7 +424,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
@@ -535,8 +536,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._websocket = None
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -551,7 +553,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
@@ -581,7 +584,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
self._partial_word = ""
|
||||
@@ -731,16 +735,20 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
else:
|
||||
await self._send_text(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ElevenLabsHttpTTSService(WordTTSService):
|
||||
@@ -1035,6 +1043,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"{self} error: {error_text}")
|
||||
yield ErrorFrame(error=f"ElevenLabs API error: {error_text}")
|
||||
return
|
||||
|
||||
@@ -1082,7 +1091,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
logger.warning(f"Failed to parse JSON from stream: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
continue
|
||||
|
||||
# After processing all chunks, emit any remaining partial word
|
||||
@@ -1106,7 +1116,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
self._previous_text = text
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame
|
||||
|
||||
@@ -110,6 +110,7 @@ class FalImageGenService(ImageGenService):
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -290,4 +290,5 @@ class FalSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -76,7 +76,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
api_key: str,
|
||||
reference_id: Optional[str] = None, # This is the voice ID
|
||||
model: Optional[str] = None, # Deprecated
|
||||
model_id: str = "s1",
|
||||
model_id: str = "speech-1.5",
|
||||
output_format: FishAudioOutputFormat = "pcm",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -93,7 +93,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
The `model` parameter is deprecated and will be removed in version 0.1.0.
|
||||
Use `reference_id` instead to specify the voice model.
|
||||
|
||||
model_id: Specify which Fish Audio TTS model to use (e.g. "s1")
|
||||
model_id: Specify which Fish Audio TTS model to use (e.g. "speech-1.5")
|
||||
output_format: Audio output format. Defaults to "pcm".
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
@@ -228,7 +228,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -242,7 +243,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(ormsgpack.packb(stop_message))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
@@ -284,7 +286,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -320,7 +323,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
flush_message = {"event": "flush"}
|
||||
await self._get_websocket().send(ormsgpack.packb(flush_message))
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -328,4 +332,5 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -468,7 +468,8 @@ class GladiaSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
@@ -558,7 +559,8 @@ class GladiaSTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
@@ -621,7 +623,8 @@ class GladiaSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
@@ -629,9 +632,7 @@ class GladiaSTTService(STTService):
|
||||
return False
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
await self.push_error(
|
||||
error_msg=f"Max reconnection attempts ({self._max_reconnection_attempts}) reached",
|
||||
)
|
||||
logger.error(f"Max reconnection attempts ({self._max_reconnection_attempts}) reached")
|
||||
self._should_reconnect = False
|
||||
return False
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
|
||||
@@ -1175,7 +1175,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}"))
|
||||
|
||||
async def _connection_task_handler(self, config: LiveConnectConfig):
|
||||
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
|
||||
@@ -1252,11 +1252,11 @@ class GeminiLiveLLMService(LLMService):
|
||||
)
|
||||
|
||||
if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
error_msg = (
|
||||
logger.error(
|
||||
f"Max consecutive failures ({MAX_CONSECUTIVE_FAILURES}) reached, "
|
||||
"treating as fatal error"
|
||||
)
|
||||
await self.push_error(error_msg=error_msg, exception=error)
|
||||
await self.push_error(ErrorFrame(error=f"{self} Error in receive loop: {error}"))
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1284,7 +1284,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._completed_tool_calls = set()
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
"""Send user audio frame to Gemini Live API."""
|
||||
@@ -1723,8 +1723,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_input_tokens=usage.cached_content_token_count,
|
||||
reasoning_tokens=usage.thoughts_token_count,
|
||||
)
|
||||
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
@@ -1745,7 +1743,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# state management, and that exponential backoff for retries can have
|
||||
# cost/stability implications for a service cluster, let's just treat a
|
||||
# send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Send error: {error}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
|
||||
@@ -110,6 +110,7 @@ class GoogleImageGenService(ImageGenService):
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not response or not response.generated_images:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
@@ -127,4 +128,5 @@ class GoogleImageGenService(ImageGenService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating image: {e}")
|
||||
yield ErrorFrame(f"Image generation error: {str(e)}")
|
||||
|
||||
@@ -793,7 +793,7 @@ class GoogleLLMService(LLMService):
|
||||
return
|
||||
generation_params.setdefault("thinking_config", {})["thinking_budget"] = 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unset thinking budget: {e}")
|
||||
logger.exception(f"Failed to unset thinking budget: {e}")
|
||||
|
||||
async def _stream_content(
|
||||
self, params_from_context: GeminiLLMInvocationParams
|
||||
@@ -983,7 +983,7 @@ class GoogleLLMService(LLMService):
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
|
||||
@@ -774,7 +774,8 @@ class GoogleSTTService(STTService):
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _stream_audio(self):
|
||||
@@ -805,13 +806,15 @@ class GoogleSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription.
|
||||
@@ -899,7 +902,8 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a
|
||||
# timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
|
||||
@@ -737,6 +737,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -995,7 +996,9 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"TTS generation error: {str(e)}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
class GeminiTTSService(GoogleBaseTTSService):
|
||||
@@ -1245,5 +1248,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"Gemini TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
@@ -1,239 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gradium's speech-to-text service implementation.
|
||||
|
||||
This module provides integration with Gradium's real-time speech-to-text
|
||||
WebSocket API for streaming audio transcription.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
"""Gradium real-time speech-to-text service.
|
||||
|
||||
Provides real-time speech transcription using Gradium's WebSocket API.
|
||||
Supports both interim and final transcriptions with configurable parameters
|
||||
for audio processing and connection management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
json_config: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._json_config = json_config
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
self._audio_buffer = bytearray()
|
||||
self._chunk_size_ms = 80
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True if metrics generation is supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Start frame to begin processing.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: End frame to stop processing.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Cancel frame to abort processing.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text conversion.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to process.
|
||||
|
||||
Yields:
|
||||
None (processing handled via WebSocket messages).
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
chunk = base64.b64encode(chunk).decode("utf-8")
|
||||
msg = {"type": "audio", "audio": chunk}
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
yield None
|
||||
|
||||
@traced_stt
|
||||
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
|
||||
"""Record transcription event for tracing."""
|
||||
pass
|
||||
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
ws_url = self._api_endpoint_base_url
|
||||
headers = {
|
||||
"x-api-key": self._api_key,
|
||||
"x-api-source": "pipecat",
|
||||
}
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
additional_headers=headers,
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"input_format": "pcm",
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Gradium STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
await self._process_messages()
|
||||
logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _process_response(self, msg):
|
||||
type_ = msg.get("type", "")
|
||||
if type_ == "text":
|
||||
await self._handle_text(msg["text"])
|
||||
elif type_ == "end_of_stream":
|
||||
await self._handle_end_of_stream()
|
||||
elif type_ == "error":
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
|
||||
async def _handle_end_of_stream(self):
|
||||
"""Handle termination message."""
|
||||
logger.debug("Received end_of_stream message from server")
|
||||
|
||||
async def _handle_text(self, text: str):
|
||||
"""Handle transcription results."""
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
@@ -1,315 +0,0 @@
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
|
||||
"""Gradium Text-to-Speech service implementation."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets import ConnectionClosedOK
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Gradium, you need to `pip install pipecat-ai[gradium]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class GradiumTTSService(InterruptibleWordTTSService):
|
||||
"""Text-to-Speech service using Gradium's websocket API."""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Gradium TTS service.
|
||||
|
||||
Parameters:
|
||||
temp: Temperature to be used for generation, defaults to 0.6.
|
||||
"""
|
||||
|
||||
temp: Optional[float] = 0.6
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "YTpq7expH9539ERJ",
|
||||
url: str = "wss://eu.api.gradium.ai/api/speech/tts",
|
||||
model: str = "default",
|
||||
json_config: Optional[str] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
voice_id: the voice identifier.
|
||||
url: Gradium websocket API endpoint.
|
||||
model: Model ID to use for synthesis.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
params: Additional configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
params = params or GradiumTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._json_config = json_config
|
||||
self._model = model
|
||||
self._settings = {
|
||||
"voice_id": voice_id,
|
||||
"model_name": model,
|
||||
"output_format": "pcm",
|
||||
}
|
||||
|
||||
# State tracking
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Gradium service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Update the TTS model.
|
||||
|
||||
Args:
|
||||
model: The model name to use for synthesis.
|
||||
"""
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
self._settings["voice_id"] = self._voice_id
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Gradium API."""
|
||||
return {"text": text, "type": "text"}
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish websocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close connection.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel current operation and clean up.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish websocket connection and start receive task."""
|
||||
logger.debug(f"{self}: connecting")
|
||||
|
||||
# If the server disconnected, cancel the receive-task so that it can be reset below.
|
||||
if self._websocket is None or self._websocket.state is not State.OPEN:
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
logger.debug(f"{self}: setting receive task")
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
logger.debug(f"{self}: disconnecting")
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Gradium websocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
|
||||
self._websocket = await websocket_connect(self._url, additional_headers=headers)
|
||||
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"output_format": "pcm",
|
||||
"voice_id": self._voice_id,
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close websocket connection and reset state."""
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
# TODO(laurent): This should not be necessary as it should happen when
|
||||
# receiving the messages but this does not seem to always be the case
|
||||
# and that may lead to a busy polling loop.
|
||||
if self._websocket and self._websocket.state is State.CLOSED:
|
||||
raise ConnectionClosedOK(None, None)
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
if msg["type"] == "audio":
|
||||
# Process audio chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
elif msg["type"] == "text":
|
||||
await self.add_word_timestamps([(msg["text"], msg["start_s"])])
|
||||
elif msg["type"] == "end_of_stream":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
|
||||
elif msg["type"] == "error":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Gradium's streaming API.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
_state = self._websocket.state if self._websocket is not None else None
|
||||
logger.debug(f"{self}: Generating TTS [{text}] {_state}")
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
self._websocket = None
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
yield TTSStartedFrame()
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -123,8 +123,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cache_read_input_tokens = None
|
||||
self._reasoning_tokens = None
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
@@ -139,8 +137,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
cache_read_input_tokens=self._cache_read_input_tokens,
|
||||
reasoning_tokens=self._reasoning_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
@@ -153,7 +149,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens, completion_tokens, and optional cached/reasoning tokens.
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
@@ -168,13 +164,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
# Capture cached & reasoning tokens (these typically only appear once per request)
|
||||
if tokens.cache_read_input_tokens is not None:
|
||||
self._cache_read_input_tokens = tokens.cache_read_input_tokens
|
||||
|
||||
if tokens.reasoning_tokens is not None:
|
||||
self._reasoning_tokens = tokens.reasoning_tokens
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
|
||||
@@ -146,6 +146,7 @@ class GroqTTSService(TTSService):
|
||||
bytes = w.readframes(num_frames)
|
||||
yield TTSAudioRawFrame(bytes, frame_rate, channels)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -179,7 +179,7 @@ class HeyGenClient:
|
||||
await self._task_manager.cancel_task(self._event_task)
|
||||
self._event_task = None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
logger.exception(f"Exception during cleanup: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame, audio_chunk_size: int) -> None:
|
||||
"""Start the client and establish all necessary connections.
|
||||
|
||||
@@ -8,14 +8,10 @@ import base64
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat import __version__
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
@@ -30,7 +26,11 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from hume import AsyncHumeClient
|
||||
from hume.tts import FormatPcm, PostedUtterance, PostedUtteranceVoiceWithId
|
||||
from hume.tts import (
|
||||
FormatPcm,
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -40,12 +40,6 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
# Tracking headers for Hume API requests
|
||||
DEFAULT_HEADERS = {
|
||||
"X-Hume-Client-Name": "pipecat",
|
||||
"X-Hume-Client-Version": __version__,
|
||||
}
|
||||
|
||||
|
||||
class HumeTTSService(WordTTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
@@ -110,11 +104,7 @@ class HumeTTSService(WordTTSService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Create a custom httpx.AsyncClient with tracking headers
|
||||
# Headers are included in all requests made by the Hume SDK
|
||||
self._http_client = httpx.AsyncClient(headers=DEFAULT_HEADERS)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key, httpx_client=self._http_client)
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
|
||||
# Store voice in the base class (mirrors other services)
|
||||
@@ -148,26 +138,6 @@ class HumeTTSService(WordTTSService):
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
"""Stop the service and cleanup resources.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if hasattr(self, "_http_client") and self._http_client:
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def cancel(self, frame: CancelFrame) -> None:
|
||||
"""Cancel the service and cleanup resources.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
if hasattr(self, "_http_client") and self._http_client:
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
@@ -317,7 +287,8 @@ class HumeTTSService(WordTTSService):
|
||||
self._cumulative_time = utterance_duration
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -397,7 +397,8 @@ class InworldTTSService(TTSService):
|
||||
# STEP 7: ERROR HANDLING
|
||||
# ================================================================================
|
||||
# Log any unexpected errors and notify the pipeline
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# ================================================================================
|
||||
# STEP 8: CLEANUP AND COMPLETION
|
||||
@@ -512,7 +513,7 @@ class InworldTTSService(TTSService):
|
||||
# Extract the base64-encoded audio content from response
|
||||
if "audioContent" not in response_data:
|
||||
logger.error("No audioContent in Inworld API response")
|
||||
yield ErrorFrame(error="No audioContent in response")
|
||||
await self.push_error(ErrorFrame(error="No audioContent in response"))
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
|
||||
@@ -173,17 +173,16 @@ class LLMService(AIService):
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: Optional[bool] = None
|
||||
self._skip_tts: bool = False
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -294,8 +293,7 @@ class LLMService(AIService):
|
||||
direction: The direction of frame pushing.
|
||||
"""
|
||||
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
if self._skip_tts is not None:
|
||||
frame.skip_tts = self._skip_tts
|
||||
frame.skip_tts = self._skip_tts
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@@ -437,7 +435,6 @@ class LLMService(AIService):
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
|
||||
runner_items = []
|
||||
for function_call in function_calls:
|
||||
if function_call.function_name in self._functions.keys():
|
||||
item = self._functions[function_call.function_name]
|
||||
@@ -449,20 +446,28 @@ class LLMService(AIService):
|
||||
)
|
||||
continue
|
||||
|
||||
runner_items.append(
|
||||
FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
|
||||
if self._run_in_parallel:
|
||||
await self._run_parallel_function_calls(runner_items)
|
||||
else:
|
||||
await self._run_sequential_function_calls(runner_items)
|
||||
if self._run_in_parallel:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def request_image_frame(
|
||||
self,
|
||||
@@ -535,27 +540,6 @@ class LLMService(AIService):
|
||||
await task
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
tasks = []
|
||||
for runner_item in runner_items:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
tasks.append(task)
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
|
||||
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
@@ -639,19 +623,20 @@ class LLMService(AIService):
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
if task:
|
||||
# We remove the callback because we are going to cancel the
|
||||
# task next, otherwise we will be removing it from the set
|
||||
# while we are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
await self.cancel_task(task)
|
||||
cancelled_tasks.add(task)
|
||||
await self.cancel_task(task)
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
|
||||
@@ -214,7 +214,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -230,7 +231,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# await self._websocket.send(json.dumps({"eof": True}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -264,9 +266,10 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "error" in msg:
|
||||
logger.error(f"{self} error: {msg['error']}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['error']}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
@@ -299,11 +302,13 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps({"flush": True}))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeAlias
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -46,24 +46,17 @@ class MCPClient(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
tools_filter: Optional[List[str]] = None,
|
||||
tools_output_filters: Optional[Dict[str, Callable[[Any], Any]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
Args:
|
||||
server_params: Server connection parameters (stdio or SSE).
|
||||
tools_filter: Optional list of tool names to register. If None, all tools are registered.
|
||||
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
|
||||
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
|
||||
**kwargs: Additional arguments passed to the parent BaseObject.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._server_params = server_params
|
||||
self._session = ClientSession
|
||||
self._tools_filter = tools_filter
|
||||
self._tools_output_filters = tools_output_filters or {}
|
||||
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
self._client = stdio_client
|
||||
@@ -183,6 +176,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _stdio_list_tools(self) -> ToolsSchema:
|
||||
@@ -213,6 +207,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _streamable_http_list_tools(self) -> ToolsSchema:
|
||||
@@ -251,6 +246,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _call_tool(self, session, function_name, arguments, result_callback):
|
||||
@@ -271,26 +267,13 @@ class MCPClient(BaseObject):
|
||||
else:
|
||||
# logger.debug(f"Non-text result content: '{content}'")
|
||||
pass
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
logger.error(f"Error getting content from {function_name} results.")
|
||||
|
||||
# Apply output filter if configured for this tool
|
||||
if function_name in self._tools_output_filters:
|
||||
try:
|
||||
response = self._tools_output_filters[function_name](response)
|
||||
logger.debug(f"Final response (after filter): {response}")
|
||||
|
||||
except Exception:
|
||||
logger.error(f"Error applying output filter for {function_name}")
|
||||
response = ""
|
||||
|
||||
if response and len(response) and isinstance(response, str):
|
||||
logger.info(f"Tool '{function_name}' completed successfully")
|
||||
logger.debug(f"Final response: {response}")
|
||||
else:
|
||||
response = "Sorry, could not call the mcp tool"
|
||||
|
||||
await result_callback(response)
|
||||
final_response = response if len(response) else "Sorry, could not call the mcp tool"
|
||||
await result_callback(final_response)
|
||||
|
||||
async def _list_tools_helper(self, session):
|
||||
available_tools = await session.list_tools()
|
||||
@@ -303,12 +286,6 @@ class MCPClient(BaseObject):
|
||||
|
||||
for tool in available_tools.tools:
|
||||
tool_name = tool.name
|
||||
|
||||
# Apply tools filter if configured
|
||||
if self._tools_filter and tool_name not in self._tools_filter:
|
||||
logger.debug(f"Skipping tool '{tool_name}' - not in allowed tools list")
|
||||
continue
|
||||
|
||||
logger.debug(f"Processing tool: {tool_name}")
|
||||
logger.debug(f"Tool description: {tool.description}")
|
||||
|
||||
@@ -325,6 +302,7 @@ class MCPClient(BaseObject):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read tool '{tool_name}': {str(e)}")
|
||||
logger.exception("Full exception details:")
|
||||
continue
|
||||
|
||||
logger.debug(f"Completed reading {len(tool_schemas)} tools")
|
||||
|
||||
@@ -253,9 +253,8 @@ class Mem0MemoryService(FrameProcessor):
|
||||
# Otherwise, pass the enhanced context frame downstream
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error processing with Mem0: {str(e)}", exception=e
|
||||
)
|
||||
logger.error(f"Error processing with Mem0: {str(e)}")
|
||||
await self.push_frame(ErrorFrame(f"Error processing with Mem0: {str(e)}"))
|
||||
await self.push_frame(frame) # Still pass the original frame through
|
||||
else:
|
||||
# For non-context frames, just pass them through
|
||||
|
||||
@@ -314,6 +314,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"MiniMax TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -391,7 +392,8 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -110,6 +110,7 @@ class MoondreamService(VisionService):
|
||||
if analysis fails.
|
||||
"""
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Moondream model not available ({self.model_name})")
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
|
||||
@@ -285,7 +285,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -298,7 +299,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Neuphonic")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -363,14 +365,16 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class NeuphonicHttpTTSService(TTSService):
|
||||
@@ -534,6 +538,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
error_message = f"Neuphonic API error: HTTP {response.status} - {error_text}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -563,7 +568,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Don't yield error frame for individual message failures
|
||||
continue
|
||||
|
||||
@@ -571,7 +577,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
logger.debug("TTS generation cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -8,23 +8,98 @@
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please NvidiaLLMService from
|
||||
pipecat.services.nvidia.llm instead.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"NimLLMService from pipecat.services.nim.llm is deprecated. "
|
||||
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
class NimLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
NimLLMService = NvidiaLLMService
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NimLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA NIM API service implementation.
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class NvidiaLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NvidiaLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
@@ -1,663 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to NVIDIA Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class NvidiaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for NVIDIA Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize NVIDIA Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use NVIDIA Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create NVIDIA Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to NVIDIA Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_nvidia_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in NVIDIA Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
@@ -1,187 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva text-to-speech service implementation.
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
NVIDIA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
@@ -133,7 +133,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
self._retry_timeout_secs = retry_timeout_secs
|
||||
self._retry_on_timeout = retry_on_timeout
|
||||
self.set_model_name(model)
|
||||
self._full_model_name: str = ""
|
||||
self._client = self.create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
@@ -186,22 +185,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
"""
|
||||
return True
|
||||
|
||||
def set_full_model_name(self, full_model_name: str):
|
||||
"""Set the full AI model name.
|
||||
|
||||
Args:
|
||||
full_model_name: The full name of the AI model to use.
|
||||
"""
|
||||
self._full_model_name = full_model_name
|
||||
|
||||
def get_full_model_name(self):
|
||||
"""Get the current full model name.
|
||||
|
||||
Returns:
|
||||
The full name of the AI model being used.
|
||||
"""
|
||||
return self._full_model_name
|
||||
|
||||
async def get_chat_completions(
|
||||
self, params_from_context: OpenAILLMInvocationParams
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -363,23 +346,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
if chunk.usage.prompt_tokens_details
|
||||
else None
|
||||
)
|
||||
reasoning_tokens = (
|
||||
chunk.usage.completion_tokens_details.reasoning_tokens
|
||||
if chunk.usage.completion_tokens_details
|
||||
else None
|
||||
)
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.model and self.get_full_model_name() != chunk.model:
|
||||
self.set_full_model_name(chunk.model)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class OpenAIImageGenService(ImageGenService):
|
||||
image_url = image.data[0].url
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} No image provided in response: {image}")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
|
||||
@@ -443,7 +444,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -460,7 +461,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
self._completed_tool_calls = set()
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
@@ -473,11 +474,12 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# somehow *started* the websocket send attempt while we still
|
||||
# had a connection)
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -655,17 +657,10 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
async def _handle_evt_response_done(self, evt):
|
||||
# todo: figure out whether there's anything we need to do for "cancelled" events
|
||||
# usage metrics
|
||||
cached_tokens = (
|
||||
evt.response.usage.input_token_details.cached_tokens
|
||||
if hasattr(evt.response.usage, "input_token_details")
|
||||
and evt.response.usage.input_token_details
|
||||
else None
|
||||
)
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=evt.response.usage.input_tokens,
|
||||
completion_tokens=evt.response.usage.output_tokens,
|
||||
total_tokens=evt.response.usage.total_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
await self.stop_processing_metrics()
|
||||
@@ -673,7 +668,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(error_msg=evt.response.status_details["error"]["message"])
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
@@ -765,7 +760,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(error_msg=f"Error: {evt}")
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
@@ -815,7 +810,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# We're done configuring the LLM for this session
|
||||
self._llm_needs_conversation_setup = False
|
||||
|
||||
logger.debug("Creating response")
|
||||
logger.debug(f"Creating response")
|
||||
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -206,4 +206,5 @@ class OpenAITTSService(TTSService):
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -79,5 +79,5 @@ class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
@@ -425,7 +425,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -441,7 +441,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
self._receive_task = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
@@ -450,11 +450,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -685,7 +686,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(error_msg=f"Error: {evt}")
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
|
||||
|
||||
async def _handle_assistant_output(self, output):
|
||||
# We haven't seen intermixed audio and function_call items in the same response. But let's
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user