Compare commits
129 Commits
fix/speech
...
fix/self-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b873198a59 | ||
|
|
5b696bd4ae | ||
|
|
b67af19d47 | ||
|
|
6d9c07b945 | ||
|
|
18429f80f1 | ||
|
|
0a54dc9721 | ||
|
|
521f669051 | ||
|
|
abb20f34ba | ||
|
|
b1e72ad4b7 | ||
|
|
f610fb95f9 | ||
|
|
827032fefb | ||
|
|
af4ef95dc6 | ||
|
|
0370bb15e4 | ||
|
|
2b3595485f | ||
|
|
63c664becb | ||
|
|
fecf462139 | ||
|
|
023063759a | ||
|
|
c49eda98e7 | ||
|
|
5d07326e36 | ||
|
|
fa659311b6 | ||
|
|
125c423356 | ||
|
|
c9615c8db6 | ||
|
|
28c542f6ed | ||
|
|
5708c81b93 | ||
|
|
82ce3ea8de | ||
|
|
62ada92188 | ||
|
|
273692421f | ||
|
|
0a3e212f93 | ||
|
|
43d686c622 | ||
|
|
4d136e1e28 | ||
|
|
2024285c75 | ||
|
|
bc830c16f1 | ||
|
|
18630c9478 | ||
|
|
3a8d3cc841 | ||
|
|
2963c7589d | ||
|
|
63caa403cb | ||
|
|
846cf0794d | ||
|
|
498349c17e | ||
|
|
474b27305f | ||
|
|
20509e8f96 | ||
|
|
5b2fa69bdc | ||
|
|
4f8cacc769 | ||
|
|
0145fb4ea0 | ||
|
|
8e52df7f03 | ||
|
|
8ee99e37ff | ||
|
|
bae4211369 | ||
|
|
859cd7c920 | ||
|
|
d608c400f9 | ||
|
|
94e93bed83 | ||
|
|
b1cee140b9 | ||
|
|
352361bdd2 | ||
|
|
baa61468a1 | ||
|
|
7501ba2e45 | ||
|
|
200716e8fe | ||
|
|
50ef4909e3 | ||
|
|
63df4642b5 | ||
|
|
43869a499d | ||
|
|
d2bf3952ec | ||
|
|
92c380ee77 | ||
|
|
a55ba40921 | ||
|
|
fb1bfd03dd | ||
|
|
a0a7b3101d | ||
|
|
39dc4ba99c | ||
|
|
a5b5a8e5cf | ||
|
|
1daea78b91 | ||
|
|
6066eec853 | ||
|
|
cd379671aa | ||
|
|
8006223911 | ||
|
|
247f0bbcd3 | ||
|
|
3537420d91 | ||
|
|
65fb88e61e | ||
|
|
b345f48ac1 | ||
|
|
f181e12d8f | ||
|
|
36de6003d0 | ||
|
|
dba4de77bf | ||
|
|
507765625f | ||
|
|
8f5e5e8e7c | ||
|
|
c682a44bb6 | ||
|
|
cb7023681f | ||
|
|
012ef41ff4 | ||
|
|
f6bb5fa124 | ||
|
|
2489c76bc6 | ||
|
|
73cb96bf66 | ||
|
|
79ec61d1d8 | ||
|
|
ca440594fe | ||
|
|
6c25dd4aa2 | ||
|
|
09bb6bb03b | ||
|
|
746fdfbfef | ||
|
|
f7af9f1efd | ||
|
|
a5f95acaf5 | ||
|
|
e50b138ab2 | ||
|
|
3640c7a2dd | ||
|
|
2454bedf29 | ||
|
|
3adb2f50a6 | ||
|
|
01b7a93e08 | ||
|
|
347eaf582d | ||
|
|
25ca296477 | ||
|
|
3fce88555f | ||
|
|
9e6f27c9f1 | ||
|
|
94f01af545 | ||
|
|
432870cc36 | ||
|
|
e065907745 | ||
|
|
b7a5ca3d1e | ||
|
|
9569625f03 | ||
|
|
18afe37bd1 | ||
|
|
2b9777b812 | ||
|
|
8866ab1585 | ||
|
|
f0995164d9 | ||
|
|
136732afae | ||
|
|
3410eb82b3 | ||
|
|
794811fbdb | ||
|
|
abea22ec57 | ||
|
|
08beb0264a | ||
|
|
2e15b4842c | ||
|
|
6d95a2425c | ||
|
|
4667a3d66d | ||
|
|
0bf2477d2c | ||
|
|
71a752c971 | ||
|
|
358f237507 | ||
|
|
a966947220 | ||
|
|
16b060d9e9 | ||
|
|
ed7fde324e | ||
|
|
beb4e86b5f | ||
|
|
2036757b84 | ||
|
|
ed3ec045aa | ||
|
|
67d39a97f7 | ||
|
|
a4e187e138 | ||
|
|
9f380170d7 | ||
|
|
12f27f9cda |
@@ -26,7 +26,7 @@ Create changelog files for the important commits in this PR. The PR number is pr
|
||||
- `{PR_NUMBER}.performance.md` - for performance improvements
|
||||
- `{PR_NUMBER}.other.md` - for other changes
|
||||
|
||||
4. Each changelog file should at least contain a main single line starting with `- ` followed by a clear description of the change.
|
||||
4. Each changelog file should at least contain a main single line starting with `- ` followed by a clear description of the change. No line wrapping.
|
||||
|
||||
5. If the change is complicated, changelog files can have indented lines after the main line with additional details or code samples.
|
||||
|
||||
|
||||
250
.claude/skills/update-docs/SKILL.md
Normal file
250
.claude/skills/update-docs/SKILL.md
Normal file
@@ -0,0 +1,250 @@
|
||||
---
|
||||
name: update-docs
|
||||
description: Update documentation pages to match source code changes on the current branch
|
||||
---
|
||||
|
||||
Update documentation pages to reflect source code changes on the current branch. Analyzes the diff against main, maps changed source files to their corresponding doc pages, and makes targeted edits.
|
||||
|
||||
## Arguments
|
||||
|
||||
```
|
||||
/update-docs [DOCS_PATH]
|
||||
```
|
||||
|
||||
- `DOCS_PATH` (optional): Path to the docs repository root. If not provided, ask the user.
|
||||
|
||||
Examples:
|
||||
- `/update-docs /Users/me/src/docs`
|
||||
- `/update-docs`
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Resolve docs path
|
||||
|
||||
If `DOCS_PATH` was provided as an argument, use it. Otherwise, ask the user for the path to their docs repository.
|
||||
|
||||
Verify the path exists and contains `server/services/` subdirectory.
|
||||
|
||||
### Step 2: Create docs branch
|
||||
|
||||
Get the current pipecat branch name:
|
||||
```bash
|
||||
git rev-parse --abbrev-ref HEAD
|
||||
```
|
||||
|
||||
In the docs repo, create a new branch off main with a matching name:
|
||||
```bash
|
||||
cd DOCS_PATH && git checkout main && git pull && git checkout -b {branch-name}-docs
|
||||
```
|
||||
|
||||
For example, if the pipecat branch is `feat/new-service`, the docs branch becomes `feat/new-service-docs`.
|
||||
|
||||
All doc edits in subsequent steps are made on this branch.
|
||||
|
||||
### Step 3: Detect changed source files
|
||||
|
||||
Run:
|
||||
```bash
|
||||
git diff main..HEAD --name-only
|
||||
```
|
||||
|
||||
Filter to files that could affect documentation:
|
||||
- `src/pipecat/services/**/*.py` (service implementations)
|
||||
- `src/pipecat/transports/**/*.py` (transport implementations)
|
||||
- `src/pipecat/serializers/**/*.py` (serializer implementations)
|
||||
- `src/pipecat/processors/**/*.py` (processor implementations)
|
||||
- `src/pipecat/audio/**/*.py` (audio utilities)
|
||||
- `src/pipecat/turns/**/*.py` (turn management)
|
||||
- `src/pipecat/observers/**/*.py` (observers)
|
||||
- `src/pipecat/pipeline/**/*.py` (pipeline core)
|
||||
|
||||
Ignore `__init__.py`, `__pycache__`, test files, and files that only contain type re-exports.
|
||||
|
||||
### Step 4: Map source files to doc pages
|
||||
|
||||
For each changed source file, find the corresponding doc page. Read the mapping file at `.claude/skills/update-docs/SOURCE_DOC_MAPPING.md` and apply its tiered lookup: tier 1 (known exceptions) → tier 2 (pattern matching) → tier 3 (search fallback). **First match wins.**
|
||||
|
||||
### Step 5: Analyze each source-doc pair
|
||||
|
||||
For each mapped pair:
|
||||
|
||||
1. **Read the full source file** to understand current state
|
||||
2. **Read the diff** for that file: `git diff main..HEAD -- <source_file>`
|
||||
3. **Read the current doc page** in full
|
||||
|
||||
Identify what changed by comparing source to docs:
|
||||
|
||||
- **Constructor parameters**: Compare `__init__` signature to the Configuration section's `<ParamField>` entries
|
||||
- **InputParams fields**: Compare `InputParams(BaseModel)` class fields to the InputParams table
|
||||
- **Event handlers**: Compare `_register_event_handler` calls and event handler definitions to Event Handlers section
|
||||
- **Class names / imports**: Check if Usage examples reference correct names
|
||||
- **Behavioral changes**: Check if Notes section needs updating
|
||||
|
||||
### Step 6: Make targeted edits
|
||||
|
||||
For each doc page that needs updates, edit **only the sections that need changes**. Preserve all other content exactly as-is.
|
||||
|
||||
#### Rules
|
||||
|
||||
- **Never remove content** unless the corresponding source code was removed
|
||||
- **Never rewrite sections** that are already accurate
|
||||
- **Match existing formatting** — if the page uses `<ParamField>` tags, use them; if it uses tables, use tables
|
||||
- **Keep descriptions concise** — match the tone and length of surrounding content
|
||||
- **Preserve CardGroup, links, and examples** unless they reference removed functionality
|
||||
- **Don't touch frontmatter** unless the class was renamed
|
||||
|
||||
#### Section-specific guidance
|
||||
|
||||
**Configuration** (constructor params):
|
||||
- Use `<ParamField path="name" type="type" default="value">` format if the page already uses it
|
||||
- Add new params in logical order (required first, then optional)
|
||||
- Remove params that no longer exist in source
|
||||
- Update types/defaults that changed
|
||||
|
||||
**InputParams** (runtime settings):
|
||||
- Use markdown table format: `| Parameter | Type | Default | Description |`
|
||||
- Match the field names and types from the `InputParams(BaseModel)` class
|
||||
- Include the default values from the source
|
||||
|
||||
**Usage** (code examples):
|
||||
- Update import paths, class names, and parameter names
|
||||
- Only modify examples if they would break or be misleading with the new API
|
||||
- Don't rewrite working examples just to add new optional params
|
||||
|
||||
**Notes**:
|
||||
- Add notes for new behavioral gotchas or breaking changes
|
||||
- Remove notes about limitations that were fixed
|
||||
- Keep existing notes that are still accurate
|
||||
|
||||
**Event Handlers**:
|
||||
- Update the event table and example code
|
||||
- Add new events, remove deleted ones
|
||||
- Update handler signatures if they changed
|
||||
|
||||
**Overview / Key Features / Prerequisites**:
|
||||
- Only update if the PR fundamentally changes what the service does (new capability, removed capability, renamed class)
|
||||
- Most PRs will NOT need changes to these sections
|
||||
|
||||
### Step 7: Update guides
|
||||
|
||||
Guides at `DOCS_PATH/guides/` reference specific class names, parameters, imports, and code patterns. After completing reference doc edits, check if any guides need updates too.
|
||||
|
||||
For each changed source file, collect the class names, renamed parameters, and changed imports from the diff. Search the guides directory:
|
||||
```bash
|
||||
grep -rl "ClassName\|old_param_name" DOCS_PATH/guides/
|
||||
```
|
||||
|
||||
For each guide that references changed code:
|
||||
1. Read the full guide
|
||||
2. Update class names, parameter names, import paths, and code examples that are now incorrect
|
||||
3. **Don't rewrite prose** — only fix the specific references that changed
|
||||
4. Leave guides alone if they reference the service generally but don't use any changed APIs
|
||||
|
||||
Guide directories:
|
||||
- `guides/learn/` — conceptual tutorials (pipeline, LLM, STT, TTS, etc.)
|
||||
- `guides/fundamentals/` — practical how-tos (metrics, recording, transcripts, etc.)
|
||||
- `guides/features/` — feature-specific guides (Gemini Live, OpenAI audio, WhatsApp, etc.)
|
||||
- `guides/telephony/` — telephony integration guides (Twilio, Plivo, Telnyx, etc.)
|
||||
|
||||
### Step 8: Identify doc gaps
|
||||
|
||||
After processing all mapped pairs, check for two kinds of gaps:
|
||||
|
||||
**Missing pages**: Source files that had no doc page mapping (neither tier 1, 2, nor 3) and are not marked as "(skip)". For each, tell the user:
|
||||
- The source file path
|
||||
- The main class(es) it defines
|
||||
- Whether a new doc page should be created
|
||||
|
||||
**Missing sections**: Mapped doc pages that are missing standard sections compared to the source. For example, a transport page with no Configuration section, or a service page with no InputParams table when the source defines `InputParams(BaseModel)`. Flag these and offer to add the missing sections.
|
||||
|
||||
If the user wants a new page, create it using this template structure:
|
||||
```
|
||||
---
|
||||
title: "Service Name"
|
||||
description: "Brief description"
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
[Description from class docstring or source analysis]
|
||||
|
||||
<CardGroup cols={2}>
|
||||
[Cards for API reference and examples if available]
|
||||
</CardGroup>
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install "pipecat-ai[package-name]"
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
[Environment variables and account setup]
|
||||
|
||||
## Configuration
|
||||
|
||||
[ParamField entries for constructor params]
|
||||
|
||||
## InputParams
|
||||
|
||||
[Table of InputParams fields, if the service has them]
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
[Minimal working example]
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
[Important caveats]
|
||||
|
||||
## Event Handlers
|
||||
|
||||
[Event table and example code]
|
||||
```
|
||||
|
||||
### Step 9: Output summary
|
||||
|
||||
After all edits are complete, print a summary:
|
||||
|
||||
```
|
||||
## Documentation Updates
|
||||
|
||||
### Updated reference pages
|
||||
- `server/services/stt/deepgram.mdx` — Updated Configuration (added `new_param`), InputParams (updated `language` default)
|
||||
- `server/services/tts/elevenlabs.mdx` — Updated Event Handlers (added `on_connected`)
|
||||
|
||||
### Updated guides
|
||||
- `guides/learn/speech-to-text.mdx` — Updated code example (renamed `old_param` → `new_param`)
|
||||
|
||||
### Unmapped source files
|
||||
- `src/pipecat/services/newprovider/tts.py` — NewProviderTTSService (no doc page exists)
|
||||
|
||||
### Skipped files
|
||||
- `src/pipecat/services/ai_service.py` — internal base class
|
||||
```
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Be conservative** — only change what the diff warrants. Don't "improve" docs beyond what changed in source.
|
||||
- **Read before editing** — always read the full doc page before making changes so you understand the existing structure.
|
||||
- **Preserve voice** — match the writing style of the existing doc page, don't impose a different tone.
|
||||
- **One PR at a time** — this skill operates on the current branch's diff against main. Don't look at other branches.
|
||||
- **Parallel analysis** — when multiple source files map to different doc pages, analyze and edit them in parallel for efficiency.
|
||||
- **Shared source files** — files like `services/google/google.py` are shared bases. Check which services import from them and update all affected doc pages.
|
||||
|
||||
## Checklist
|
||||
|
||||
Before finishing, verify:
|
||||
|
||||
- [ ] All changed source files were checked against the mapping table
|
||||
- [ ] Each doc page edit matches the actual source code change (not guessed)
|
||||
- [ ] No content was removed unless the corresponding source was removed
|
||||
- [ ] New parameters have accurate types and defaults from source
|
||||
- [ ] Formatting matches the existing page style
|
||||
- [ ] Guides referencing changed APIs were checked and updated
|
||||
- [ ] Unmapped files were reported to the user
|
||||
79
.claude/skills/update-docs/SOURCE_DOC_MAPPING.md
Normal file
79
.claude/skills/update-docs/SOURCE_DOC_MAPPING.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# Source-to-Doc Mapping
|
||||
|
||||
Maps pipecat source files to their documentation pages. Source paths are relative to `src/pipecat/`. Doc paths are relative to `DOCS_PATH`.
|
||||
|
||||
## Name mismatches
|
||||
|
||||
These source paths don't follow the standard `services/{provider}/{type}.py` → `server/services/{type}/{provider}.mdx` pattern.
|
||||
|
||||
| Source path | Doc page |
|
||||
|---|---|
|
||||
| `services/google/llm.py` | `server/services/llm/gemini.mdx` |
|
||||
| `services/google/llm_vertex.py` | `server/services/llm/google-vertex.mdx` |
|
||||
| `services/google/google.py` | (shared base — check which services use it) |
|
||||
| `services/google/gemini_live/**` | `server/services/s2s/gemini-live.mdx` |
|
||||
| `services/google/gemini_live/llm_vertex.py` | `server/services/s2s/gemini-live-vertex.mdx` |
|
||||
| `services/aws_nova_sonic/**` | `server/services/s2s/aws.mdx` |
|
||||
| `services/ultravox/**` | `server/services/s2s/ultravox.mdx` |
|
||||
| `services/grok/realtime/**` | `server/services/s2s/grok.mdx` |
|
||||
| `services/openai/realtime/**` | `server/services/s2s/openai.mdx` |
|
||||
| `processors/frameworks/rtvi.py` | `server/frameworks/rtvi/rtvi-processor.mdx` and `server/frameworks/rtvi/rtvi-observer.mdx` |
|
||||
| `processors/transcript_processor.py` | `server/utilities/transcript-processor.mdx` |
|
||||
| `processors/user_idle_processor.py` | `server/utilities/user-idle-processor.mdx` |
|
||||
| `processors/idle_frame_processor.py` | `server/pipeline/pipeline-idle-detection.mdx` |
|
||||
| `pipeline/task.py` | `server/pipeline/pipeline-task.mdx` |
|
||||
| `pipeline/runner.py` | `server/utilities/runner/guide.mdx` |
|
||||
| `transports/base_transport.py` | `server/services/transport/transport-params.mdx` |
|
||||
|
||||
## Skip list
|
||||
|
||||
These files should never trigger doc updates.
|
||||
|
||||
| Pattern | Reason |
|
||||
|---|---|
|
||||
| `services/ai_service.py` | Internal base class |
|
||||
| `services/stt_service.py` | Internal base class |
|
||||
| `services/tts_service.py` | Internal base class |
|
||||
| `services/llm_service.py` | Internal base class |
|
||||
| `services/websocket_service.py` | Internal base class |
|
||||
| `services/openai_realtime_beta/**` | Deprecated |
|
||||
| `services/openai_realtime/**` | Deprecated |
|
||||
| `services/gemini_multimodal_live/**` | Deprecated |
|
||||
| `services/aws/agent_core.py` | Internal |
|
||||
| `services/aws/sagemaker/**` | No doc page |
|
||||
| `transports/base_input.py` | Internal base class |
|
||||
| `transports/base_output.py` | Internal base class |
|
||||
| `transports/websocket/client.py` | No doc page |
|
||||
| `serializers/base_serializer.py` | Internal base class |
|
||||
| `serializers/protobuf.py` | Internal |
|
||||
| `processors/audio/**` | Internal |
|
||||
| `pipeline/pipeline.py` | Core architecture, not a service doc |
|
||||
|
||||
## Pattern matching
|
||||
|
||||
For files not in the tables above, apply these patterns. Convert underscores to hyphens in provider names for doc filenames.
|
||||
|
||||
| Source pattern | Doc pattern |
|
||||
|---|---|
|
||||
| `services/{provider}/stt*.py` | `server/services/stt/{provider}.mdx` |
|
||||
| `services/{provider}/tts*.py` | `server/services/tts/{provider}.mdx` |
|
||||
| `services/{provider}/llm*.py` | `server/services/llm/{provider}.mdx` |
|
||||
| `services/{provider}/image*.py` | `server/services/image-generation/{provider}.mdx` |
|
||||
| `services/{provider}/video*.py` | `server/services/video/{provider}.mdx` |
|
||||
| `services/{provider}/realtime/**` | `server/services/s2s/{provider}.mdx` |
|
||||
| `transports/{name}/**` | `server/services/transport/{name}.mdx` |
|
||||
| `serializers/{name}.py` | `server/services/serializers/{name}.mdx` |
|
||||
| `observers/**` | `server/utilities/observers/` (match by class name) |
|
||||
| `audio/vad/**` | `server/utilities/audio/` (match by class name) |
|
||||
| `audio/filters/**` | `server/utilities/audio/` (match by class name) |
|
||||
| `audio/mixers/**` | `server/utilities/audio/` (match by class name) |
|
||||
| `processors/filters/**` | `server/utilities/filters/` (match by class name) |
|
||||
|
||||
If the doc file doesn't exist at the resolved path, the file is **unmapped**.
|
||||
|
||||
## Search fallback
|
||||
|
||||
For files that don't match any table or pattern above:
|
||||
1. Extract the main class name(s) from the source file
|
||||
2. Search the docs directory for that class name: `grep -r "ClassName" DOCS_PATH/server/`
|
||||
3. If found in a doc page, use that as the mapping
|
||||
2
.github/workflows/coverage.yaml
vendored
2
.github/workflows/coverage.yaml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
|
||||
- name: Install system packages
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y portaudio19-dev
|
||||
|
||||
- name: Install dependencies
|
||||
@@ -41,6 +42,7 @@ jobs:
|
||||
--extra livekit \
|
||||
--extra local-smart-turn-v3 \
|
||||
--extra piper \
|
||||
--extra tracing \
|
||||
--extra websocket
|
||||
|
||||
- name: Run tests with coverage
|
||||
|
||||
2
.github/workflows/generate-changelog.yml
vendored
2
.github/workflows/generate-changelog.yml
vendored
@@ -86,7 +86,7 @@ jobs:
|
||||
fi
|
||||
|
||||
# Validate fragment types
|
||||
VALID_TYPES="added changed deprecated removed fixed security other"
|
||||
VALID_TYPES="added changed deprecated removed fixed performance security other"
|
||||
INVALID_FRAGMENTS=""
|
||||
|
||||
for file in changelog/*.md; do
|
||||
|
||||
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@@ -33,6 +33,7 @@ jobs:
|
||||
|
||||
- name: Install system packages
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y portaudio19-dev
|
||||
|
||||
- name: Install dependencies
|
||||
@@ -45,6 +46,7 @@ jobs:
|
||||
--extra livekit \
|
||||
--extra local-smart-turn-v3 \
|
||||
--extra piper \
|
||||
--extra tracing \
|
||||
--extra websocket
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
209
CHANGELOG.md
209
CHANGELOG.md
@@ -7,6 +7,215 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
<!-- towncrier release notes start -->
|
||||
|
||||
## [0.0.103] - 2026-02-20
|
||||
|
||||
### Added
|
||||
|
||||
- Added `"timestampTransportStrategy": "ASYNC"` to `InworldAITTSService`. This
|
||||
allows timestamps info to trail audio chunks arrival, resulting in much
|
||||
better first audio chunk latency
|
||||
(PR [#3625](https://github.com/pipecat-ai/pipecat/pull/3625))
|
||||
|
||||
- Added model-specific `InputParams` to `RimeTTSService`: arcana params
|
||||
(`repetition_penalty`, `temperature`, `top_p`) and mistv2 params
|
||||
(`no_text_normalization`, `save_oovs`, `segment`). Model, voice, and param
|
||||
changes now trigger WebSocket reconnection.
|
||||
(PR [#3642](https://github.com/pipecat-ai/pipecat/pull/3642))
|
||||
|
||||
- Added `write_transport_frame()` hook to `BaseOutputTransport` allowing
|
||||
transport subclasses to handle custom frame types that flow through the audio
|
||||
queue.
|
||||
(PR [#3719](https://github.com/pipecat-ai/pipecat/pull/3719))
|
||||
|
||||
- Added `DailySIPTransferFrame` and `DailySIPReferFrame` to the Daily
|
||||
transport. These frames queue SIP transfer and SIP REFER operations with
|
||||
audio, so the operation executes only after the bot finishes its current
|
||||
utterance.
|
||||
(PR [#3719](https://github.com/pipecat-ai/pipecat/pull/3719))
|
||||
|
||||
- Added keepalive support to `SarvamSTTService` to prevent idle connection
|
||||
timeouts (e.g. when used behind a `ServiceSwitcher`).
|
||||
(PR [#3730](https://github.com/pipecat-ai/pipecat/pull/3730))
|
||||
|
||||
- Added `UserIdleTimeoutUpdateFrame` to enable or disable user idle detection
|
||||
at runtime by updating the timeout dynamically.
|
||||
(PR [#3748](https://github.com/pipecat-ai/pipecat/pull/3748))
|
||||
|
||||
- Added `broadcast_sibling_id` field to the base `Frame` class. This field is
|
||||
automatically set by `broadcast_frame()` and `broadcast_frame_instance()` to
|
||||
the ID of the paired frame pushed in the opposite direction, allowing
|
||||
receivers to identify broadcast pairs.
|
||||
(PR [#3774](https://github.com/pipecat-ai/pipecat/pull/3774))
|
||||
|
||||
- Added `ignored_sources` parameter to `RTVIObserverParams` and
|
||||
`add_ignored_source()`/`remove_ignored_source()` methods to `RTVIObserver` to
|
||||
suppress RTVI messages from specific pipeline processors (e.g. a silent
|
||||
evaluation LLM).
|
||||
(PR [#3779](https://github.com/pipecat-ai/pipecat/pull/3779))
|
||||
|
||||
- Added `DeepgramSageMakerTTSService` for running Deepgram TTS models deployed
|
||||
on AWS SageMaker endpoints via HTTP/2 bidirectional streaming. Supports the
|
||||
Deepgram TTS protocol (Speak, Flush, Clear, Close), interruption handling,
|
||||
and per-turn TTFB metrics.
|
||||
(PR [#3785](https://github.com/pipecat-ai/pipecat/pull/3785))
|
||||
|
||||
### Changed
|
||||
|
||||
- ⚠️ `RimeTTSService` now defaults to `model="arcana"` and the
|
||||
`wss://users-ws.rime.ai/ws3` endpoint. `InputParams` defaults changed from
|
||||
mistv2-specific values to `None` — only explicitly-set params are sent as
|
||||
query params.
|
||||
(PR [#3642](https://github.com/pipecat-ai/pipecat/pull/3642))
|
||||
|
||||
- `AICFilter` now shares read-only AIC models via a singleton `AICModelManager`
|
||||
in `aic_filter.py`.
|
||||
- Multiple filters using the same model path or `(model_id,
|
||||
model_download_dir)` share one loaded model, with reference counting and
|
||||
concurrent load deduplication.
|
||||
- Model file I/O runs off the event loop so the filter does not block.
|
||||
(PR [#3684](https://github.com/pipecat-ai/pipecat/pull/3684))
|
||||
|
||||
- Added `X-User-Agent` and `X-Request-Id` headers to `InworldTTSService` for
|
||||
better traceability.
|
||||
(PR [#3706](https://github.com/pipecat-ai/pipecat/pull/3706))
|
||||
|
||||
- `DailyUpdateRemoteParticipantsFrame` is no longer deprecated and is now
|
||||
queued with audio like other transport frames.
|
||||
(PR [#3719](https://github.com/pipecat-ai/pipecat/pull/3719))
|
||||
|
||||
- Bumped Pillow dependency upper bound from `<12` to `<13` to allow Pillow
|
||||
12.x.
|
||||
(PR [#3728](https://github.com/pipecat-ai/pipecat/pull/3728))
|
||||
|
||||
- Moved STT keepalive mechanism from `WebsocketSTTService` to the `STTService`
|
||||
base class, allowing any STT service (not just websocket-based ones) to use
|
||||
idle-connection keepalive via the `keepalive_timeout` and
|
||||
`keepalive_interval` parameters.
|
||||
(PR [#3730](https://github.com/pipecat-ai/pipecat/pull/3730))
|
||||
|
||||
- Improved audio context management in `AudioContextTTSService` by moving
|
||||
context ID tracking to the base class and adding
|
||||
`reuse_context_id_within_turn` parameter to control concurrent TTS request
|
||||
handling.
|
||||
- Added helper methods: `has_active_audio_context()`,
|
||||
`get_active_audio_context_id()`, `remove_active_audio_context()`,
|
||||
`reset_active_audio_context()`
|
||||
- Simplified Cartesia, ElevenLabs, Inworld, Rime, AsyncAI, and Gradium TTS
|
||||
implementations by removing duplicate context management code
|
||||
(PR [#3732](https://github.com/pipecat-ai/pipecat/pull/3732))
|
||||
|
||||
- `UserIdleController` is now always created with a default timeout of 0
|
||||
(disabled). The `user_idle_timeout` parameter changed from `Optional[float] =
|
||||
None` to `float = 0` in `UserTurnProcessor`, `LLMUserAggregatorParams`, and
|
||||
`UserIdleController`.
|
||||
(PR [#3748](https://github.com/pipecat-ai/pipecat/pull/3748))
|
||||
|
||||
- Change the version specifier from `>=0.2.8` to `~=0.2.8` for the
|
||||
`speechmatics-voice` package to ensure compatibility with future patch
|
||||
versions.
|
||||
(PR [#3761](https://github.com/pipecat-ai/pipecat/pull/3761))
|
||||
|
||||
- Updated `InworldTTSService` and `InworldHttpTTSService` to use `ASYNC`
|
||||
timestamp transport strategy by default
|
||||
(PR [#3765](https://github.com/pipecat-ai/pipecat/pull/3765))
|
||||
|
||||
- Added `start_time` and `end_time` parameters to `start_ttfb_metrics()`,
|
||||
`stop_ttfb_metrics()`, `start_processing_metrics()`, and
|
||||
`stop_processing_metrics()` in `FrameProcessor` and `FrameProcessorMetrics`,
|
||||
allowing custom timestamps for metrics measurement. `STTService` now uses
|
||||
these instead of custom TTFB tracking.
|
||||
(PR [#3776](https://github.com/pipecat-ai/pipecat/pull/3776))
|
||||
|
||||
- Updated default Anthropic model from `claude-sonnet-4-5-20250929` to
|
||||
`claude-sonnet-4-6`.
|
||||
(PR [#3792](https://github.com/pipecat-ai/pipecat/pull/3792))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated unused `Traceable`, `@traceable`, `@traced`, and
|
||||
`AttachmentStrategy` in `pipecat.utils.tracing.class_decorators`. This module
|
||||
will be removed in a future release.
|
||||
(PR [#3733](https://github.com/pipecat-ai/pipecat/pull/3733))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed race condition where `RTVIObserver` could send messages before
|
||||
`DailyTransport` join completed. Outbound messages are now queued & delivered
|
||||
after the transport is ready.
|
||||
(PR [#3615](https://github.com/pipecat-ai/pipecat/pull/3615))
|
||||
|
||||
- Fixed async generator cleanup in OpenAI LLM streaming to prevent
|
||||
`AttributeError` with uvloop on Python 3.12+ (MagicStack/uvloop#699).
|
||||
(PR [#3698](https://github.com/pipecat-ai/pipecat/pull/3698))
|
||||
|
||||
- Fixed `SmallWebRTCTransport` input audio resampling to properly handle all
|
||||
sample rates, including 8kHz audio.
|
||||
(PR [#3713](https://github.com/pipecat-ai/pipecat/pull/3713))
|
||||
|
||||
- Fixed a race condition in `RTVIObserver` where bot output messages could be
|
||||
sent before the bot-started-speaking event.
|
||||
(PR [#3718](https://github.com/pipecat-ai/pipecat/pull/3718))
|
||||
|
||||
- Fixed Grok Realtime `session.updated` event parsing failure caused by the API
|
||||
returning prefixed voice names (e.g. `"human_Ara"` instead of `"Ara"`).
|
||||
(PR [#3720](https://github.com/pipecat-ai/pipecat/pull/3720))
|
||||
|
||||
- Fixed context ID reuse issue in `ElevenLabsTTSService`, `InworldTTSService`,
|
||||
`RimeTTSService`, `CartesiaTTSService`, `AsyncAITTSService`, and
|
||||
`PlayHTTTSService`. Services now properly reuse the same context ID across
|
||||
multiple `run_tts()` invocations within a single LLM turn, preventing context
|
||||
tracking issues and incorrect lifecycle signaling.
|
||||
(PR [#3729](https://github.com/pipecat-ai/pipecat/pull/3729))
|
||||
|
||||
- Fixed word timestamp interleaving issue in `ElevenLabsTTSService` when
|
||||
processing multiple sentences within a single LLM turn.
|
||||
(PR [#3729](https://github.com/pipecat-ai/pipecat/pull/3729))
|
||||
|
||||
- Fixed tracing service decorators executing the wrapped function twice when
|
||||
the function itself raised an exception (e.g., LLM rate limit, TTS timeout).
|
||||
(PR [#3735](https://github.com/pipecat-ai/pipecat/pull/3735))
|
||||
|
||||
- Fixed `LLMUserAggregator` broadcasting mute events before `StartFrame`
|
||||
reaches downstream processors.
|
||||
(PR [#3737](https://github.com/pipecat-ai/pipecat/pull/3737))
|
||||
|
||||
- Fixed `UserIdleController` false idle triggers caused by gaps between user
|
||||
and bot activity frames. The idle timer now starts only after
|
||||
`BotStoppedSpeakingFrame` and is suppressed during active user turns and
|
||||
function calls.
|
||||
(PR [#3744](https://github.com/pipecat-ai/pipecat/pull/3744))
|
||||
|
||||
- Fixed incorrect `sample_rate` assignment in
|
||||
`TavusInputTransport._on_participant_audio_data` (was using
|
||||
`audio.audio_frames` instead of `audio.sample_rate`).
|
||||
(PR [#3768](https://github.com/pipecat-ai/pipecat/pull/3768))
|
||||
|
||||
- Fixed `RTVIObserver` not processing upstream-only frames. Previously, all
|
||||
upstream frames were filtered out to avoid duplicate messages from
|
||||
broadcasted frames. Now only upstream copies of broadcasted frames are
|
||||
skipped.
|
||||
(PR [#3774](https://github.com/pipecat-ai/pipecat/pull/3774))
|
||||
|
||||
- Fixed mutable default arguments in `LLMContextAggregatorPair.__init__()` that
|
||||
could cause shared state across instances.
|
||||
(PR [#3782](https://github.com/pipecat-ai/pipecat/pull/3782))
|
||||
|
||||
- Fixed `DeepgramSageMakerSTTService` to properly track finalize lifecycle
|
||||
using `request_finalize()` / `confirm_finalize()` and use `is_final` (instead
|
||||
of `is_final and speech_final`) for final transcription detection, matching
|
||||
`DeepgramSTTService` behavior.
|
||||
(PR [#3784](https://github.com/pipecat-ai/pipecat/pull/3784))
|
||||
|
||||
- Fixed a race condition in `AudioContextTTSService` where the audio context
|
||||
could time out between consecutive TTS requests within the same turn, causing
|
||||
audio to be discarded.
|
||||
(PR [#3787](https://github.com/pipecat-ai/pipecat/pull/3787))
|
||||
|
||||
- Fixed `push_interruption_task_frame_and_wait()` hanging indefinitely when the
|
||||
`InterruptionFrame` does not reach the pipeline sink within the timeout.
|
||||
Added a `timeout` keyword argument to customize the wait duration.
|
||||
(PR [#3789](https://github.com/pipecat-ai/pipecat/pull/3789))
|
||||
|
||||
## [0.0.102] - 2026-02-10
|
||||
|
||||
### Added
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
- Added `X-User-Agent` and `X-Request-Id` headers to `InworldTTSService` for better traceability.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed `SmallWebRTCTransport` input audio resampling to properly handle all sample rates, including 8kHz audio.
|
||||
@@ -1,5 +0,0 @@
|
||||
- Fixed a race condition in `SpeechTimeoutUserTurnStopStrategy` where a finalized
|
||||
transcript arriving after `user_speech_timeout` elapsed from VAD stop would
|
||||
immediately trigger a turn stop, even if the user was still speaking. STT
|
||||
processing latency was consuming the `user_speech_timeout` window, leaving no
|
||||
time for the user to resume speaking.
|
||||
1
changelog/3759.performance.md
Normal file
1
changelog/3759.performance.md
Normal file
@@ -0,0 +1 @@
|
||||
- Switched `GradiumTTSService` from `InterruptibleWordTTSService` to `AudioContextWordTTSService`, eliminating websocket disconnect/reconnect on every interruption by using `client_req_id`-based multiplexing.
|
||||
1
changelog/3802.fixed.md
Normal file
1
changelog/3802.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed self-referential `pipecat-ai[local-smart-turn-v3]` dependency in `pyproject.toml` that caused Poetry 2.x to fail with a circular dependency error. The underlying packages (`transformers`, `onnxruntime`) are now listed directly in main dependencies.
|
||||
@@ -47,7 +47,8 @@ DAILY_ROOM_URL=https://...
|
||||
|
||||
# Deepgram
|
||||
DEEPGRAM_API_KEY=...
|
||||
SAGEMAKER_ENDPOINT_NAME=...
|
||||
SAGEMAKER_STT_ENDPOINT_NAME=...
|
||||
SAGEMAKER_TTS_ENDPOINT_NAME=...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=...
|
||||
|
||||
@@ -24,7 +24,7 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.deepgram.tts_sagemaker import DeepgramSageMakerTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -58,11 +58,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# - AWS credentials configured (via environment variables or AWS CLI)
|
||||
# - A deployed SageMaker endpoint with Deepgram model
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME"),
|
||||
endpoint_name=os.getenv("SAGEMAKER_STT_ENDPOINT_NAME"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
)
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
# Initialize Deepgram SageMaker TTS Service
|
||||
# This requires:
|
||||
# - AWS credentials configured (via environment variables or AWS CLI)
|
||||
# - A deployed SageMaker endpoint with Deepgram TTS model
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name=os.getenv("SAGEMAKER_TTS_ENDPOINT_NAME"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
voice="aura-2-andromeda-en",
|
||||
)
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region=os.getenv("AWS_REGION"),
|
||||
|
||||
@@ -56,7 +56,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
tts = RimeTTSService(
|
||||
api_key=os.getenv("RIME_API_KEY", ""),
|
||||
voice_id="rex",
|
||||
voice_id="luna",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
@@ -96,7 +96,7 @@ class UserAudioCollector(FrameProcessor):
|
||||
self._user_speaking = True
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._user_context_aggregator.push_frame(LLMRunFrame())
|
||||
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
|
||||
@@ -72,10 +72,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
model="claude-3-7-sonnet-latest",
|
||||
)
|
||||
llm = AnthropicLLMService(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
|
||||
@@ -5,17 +5,21 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
EndTaskFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMRunFrame,
|
||||
TTSSpeakFrame,
|
||||
UserIdleTimeoutUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -30,6 +34,7 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -74,6 +79,17 @@ class IdleHandler:
|
||||
await aggregator.push_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
# Simulate a slow API call, waiting longer than the user idle timeout.
|
||||
await asyncio.sleep(3)
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await asyncio.sleep(6)
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
@@ -104,6 +120,42 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -111,7 +163,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context = LLMContext(messages, tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
@@ -146,6 +198,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(aggregator):
|
||||
logger.info(f"User turn idle")
|
||||
await idle_handler.handle_idle(aggregator)
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_started")
|
||||
@@ -158,6 +211,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
await asyncio.sleep(30)
|
||||
logger.info(f"Disabling idle detection")
|
||||
await task.queue_frames([UserIdleTimeoutUpdateFrame(timeout=0)])
|
||||
await asyncio.sleep(30)
|
||||
logger.info(f"Enabling idle detection")
|
||||
await task.queue_frames([UserIdleTimeoutUpdateFrame(timeout=5)])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -98,7 +98,7 @@ class UserAudioCollector(FrameProcessor):
|
||||
self._user_speaking = True
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._user_context_aggregator.push_frame(LLMContextFrame(context=self._context))
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if self._user_speaking:
|
||||
|
||||
191
examples/foundational/53-concurrent-llm-rtvi-ignored-sources.py
Normal file
191
examples/foundational/53-concurrent-llm-rtvi-ignored-sources.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""RTVIObserver ignored sources example.
|
||||
|
||||
This example shows how to suppress RTVI messages from a specific pipeline
|
||||
processor so that secondary branches don't leak events to the client.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.audio.vad_processor import VADProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserverParams
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_turn_processor import UserTurnProcessor
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Main LLM — drives the conversation. Its RTVI events reach the client.
|
||||
main_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
main_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
# Evaluator LLM — silently grades the user's message in the background.
|
||||
# Its RTVI events will be suppressed so the client is unaware of this branch.
|
||||
evaluator_llm = OpenAILLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
name="EvaluatorLLM",
|
||||
)
|
||||
|
||||
evaluator_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a silent quality evaluator. When given a user message, "
|
||||
"respond with a single JSON object: "
|
||||
'{"score": <1-5>, "reason": "<brief reason>"}. '
|
||||
"Do not respond conversationally."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
main_context = LLMContext(main_messages)
|
||||
evaluator_context = LLMContext(evaluator_messages)
|
||||
|
||||
# We use an external VADProcessor because the UserTurnProcessor is shared
|
||||
# across multiple parallel aggregators. The VADProcessor emits
|
||||
# VADUserStartedSpeakingFrame and VADUserStoppedSpeakingFrame which the
|
||||
# UserTurnProcessor needs to manage turn lifecycle.
|
||||
vad_processor = VADProcessor(vad_analyzer=SileroVADAnalyzer())
|
||||
|
||||
# We use this external user turn processor. This processor will push
|
||||
# UserStartedSpeakingFrame and UserStoppedSpeakingFrame as well as
|
||||
# interruptions. This can be used in advanced cases when there are multiple
|
||||
# aggregators in the pipeline.
|
||||
user_turn_processor = UserTurnProcessor()
|
||||
|
||||
# We use external user turn strategies for both aggregators since the turn
|
||||
# management is done by the common UserTurnProcessor.
|
||||
main_context_aggregator = LLMContextAggregatorPair(
|
||||
main_context,
|
||||
user_params=LLMUserAggregatorParams(user_turn_strategies=ExternalUserTurnStrategies()),
|
||||
)
|
||||
evaluator_context_aggregator = LLMContextAggregatorPair(
|
||||
evaluator_context,
|
||||
user_params=LLMUserAggregatorParams(user_turn_strategies=ExternalUserTurnStrategies()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
vad_processor,
|
||||
user_turn_processor,
|
||||
ParallelPipeline(
|
||||
# Main branch: speaks to the user.
|
||||
[
|
||||
main_context_aggregator.user(),
|
||||
main_llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
main_context_aggregator.assistant(),
|
||||
],
|
||||
# Evaluator branch: silent background scoring, no audio output.
|
||||
[
|
||||
evaluator_context_aggregator.user(),
|
||||
evaluator_llm,
|
||||
evaluator_context_aggregator.assistant(),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
rtvi_observer_params=RTVIObserverParams(ignored_sources=[evaluator_llm]),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
main_messages.append(
|
||||
{"role": "system", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
evaluator_messages.append({"role": "system", "content": "Ready to evaluate user messages."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -28,7 +28,7 @@ dependencies = [
|
||||
"Markdown>=3.7,<4",
|
||||
"nltk>=3.9.1,<4",
|
||||
"numpy>=1.26.4,<3",
|
||||
"Pillow>=11.1.0,<12",
|
||||
"Pillow>=11.1.0,<13",
|
||||
"protobuf~=5.29.6",
|
||||
"pydantic>=2.10.6,<3",
|
||||
"pyloudnorm~=0.1.1",
|
||||
@@ -38,8 +38,9 @@ dependencies = [
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
# Pipecat optionals
|
||||
"pipecat-ai[local-smart-turn-v3]",
|
||||
# Local smart turn v3 (inlined to avoid self-referential dependency)
|
||||
"transformers",
|
||||
"onnxruntime~=1.23.2",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -105,16 +106,16 @@ remote-smart-turn = []
|
||||
resembleai = [ "pipecat-ai[websockets-base]" ]
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "pipecat-ai[nvidia]" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.128.0", "pipecat-ai-small-webrtc-prebuilt>=2.1.0"]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.128.0", "pipecat-ai-small-webrtc-prebuilt>=2.2.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
|
||||
sarvam = [ "sarvamai==0.1.26a2", "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk>=2.28.0,<3" ]
|
||||
silero = [ "onnxruntime~=1.23.2" ]
|
||||
simli = [ "simli-ai~=1.0.3"]
|
||||
simli = [ "simli-ai~=2.0.1"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.1" ]
|
||||
speechmatics = [ "speechmatics-voice[smart]>=0.2.8" ]
|
||||
speechmatics = [ "speechmatics-voice[smart]~=0.2.8" ]
|
||||
strands = [ "strands-agents>=1.9.1,<2" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
|
||||
@@ -12,10 +12,13 @@ the Koala filter and integrates with Pipecat's input transport pipeline.
|
||||
|
||||
Classes:
|
||||
AICFilter: For aic-sdk (uses 'aic_sdk' module)
|
||||
AICModelManager: Singleton manager for read-only AIC Model instances.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from aic_sdk import (
|
||||
@@ -33,6 +36,177 @@ from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
|
||||
class AICModelManager:
|
||||
"""Singleton manager for read-only AIC Model instances with reference counting.
|
||||
|
||||
Caches Model instances by path or (model_id + download_dir). Multiple
|
||||
AICFilter instances using the same model share one Model; the manager
|
||||
acquires on first use and releases when the last reference is dropped.
|
||||
"""
|
||||
|
||||
_cache: dict[str, Tuple[Model, int]] = {} # key -> (model, ref_count)
|
||||
_lock = Lock()
|
||||
_loading: dict[
|
||||
str, asyncio.Task[Model]
|
||||
] = {} # key -> load task (deduplicates concurrent loads)
|
||||
|
||||
@classmethod
|
||||
def _increment_reference(cls, cache_key: str, entry: Tuple[Model, int]) -> Tuple[Model, str]:
|
||||
"""Increment reference count for cached entry. Caller must hold _lock."""
|
||||
cached_model, ref_count = entry
|
||||
cls._cache[cache_key] = (cached_model, ref_count + 1)
|
||||
logger.debug(f"AIC model cache key={cache_key!r} ref_count={ref_count + 1}")
|
||||
return cached_model, cache_key
|
||||
|
||||
@classmethod
|
||||
def _store_new_reference(cls, cache_key: str, model: Model) -> Tuple[Model, str]:
|
||||
"""Store new model in cache with ref count 1. Caller must hold _lock."""
|
||||
cls._cache[cache_key] = (model, 1)
|
||||
logger.debug(f"AIC model cached key={cache_key!r} ref_count=1")
|
||||
return model, cache_key
|
||||
|
||||
@classmethod
|
||||
async def _load_model_from_file(
|
||||
cls,
|
||||
cache_key: str,
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> Model:
|
||||
"""Run the actual load (file or download). Separate to allow create_task and deduplication."""
|
||||
if model_path is not None:
|
||||
logger.debug(f"Loading AIC model from file: {model_path}")
|
||||
model_path_str = str(model_path)
|
||||
|
||||
elif model_id is not None and model_download_dir is not None:
|
||||
logger.debug(f"Downloading AIC model: {model_id}")
|
||||
model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path_str = await Model.download_async(model_id, str(model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path_str}")
|
||||
|
||||
else:
|
||||
raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: Model.from_file(model_path_str))
|
||||
|
||||
@staticmethod
|
||||
def _get_cache_key(
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> str:
|
||||
"""Build a stable cache key for the model.
|
||||
|
||||
Args:
|
||||
model_path: Path to a local .aicmodel file.
|
||||
model_id: Model identifier (See https://artifacts.ai-coustics.io/ for available models).
|
||||
model_download_dir: Directory used for downloading models.
|
||||
|
||||
Returns:
|
||||
A string key unique per (path) or (model_id + download_dir).
|
||||
"""
|
||||
if model_path is not None:
|
||||
return f"path:{model_path.resolve()}"
|
||||
|
||||
if model_id is not None and model_download_dir is not None:
|
||||
return f"id:{model_id}:{model_download_dir.resolve()}"
|
||||
|
||||
raise ValueError("Either model_path or (model_id and model_download_dir) must be set.")
|
||||
|
||||
@classmethod
|
||||
async def acquire(
|
||||
cls,
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> Tuple[Model, str]:
|
||||
"""Get or load a Model and increment its reference count.
|
||||
|
||||
Call this when starting a filter. Store the returned key and pass it
|
||||
to release() when stopping the filter.
|
||||
|
||||
Args:
|
||||
model_path: Path to a local .aicmodel file. If set, model_id is ignored.
|
||||
model_id: Model identifier to download from CDN.
|
||||
model_download_dir: Directory for downloading models. Required if
|
||||
model_id is used.
|
||||
|
||||
Returns:
|
||||
Tuple of (shared Model instance, cache key for release).
|
||||
|
||||
Raises:
|
||||
ValueError: If neither model_path nor (model_id + model_download_dir)
|
||||
is provided, or if model_id is set without model_download_dir.
|
||||
"""
|
||||
cache_key = cls._get_cache_key(
|
||||
model_path=model_path,
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(cache_key)
|
||||
if entry is not None:
|
||||
return cls._increment_reference(cache_key, entry)
|
||||
|
||||
# Deduplicate concurrent loads for the same key
|
||||
load_task = cls._loading.get(cache_key)
|
||||
if load_task is None:
|
||||
load_task = asyncio.create_task(
|
||||
cls._load_model_from_file(
|
||||
cache_key,
|
||||
model_path=model_path,
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
)
|
||||
cls._loading[cache_key] = load_task
|
||||
|
||||
try:
|
||||
model = await load_task
|
||||
finally:
|
||||
with cls._lock:
|
||||
cls._loading.pop(cache_key, None)
|
||||
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(cache_key)
|
||||
if entry is not None:
|
||||
return cls._increment_reference(cache_key, entry)
|
||||
return cls._store_new_reference(cache_key, model)
|
||||
|
||||
@classmethod
|
||||
def release(cls, key: str) -> None:
|
||||
"""Release a reference to a cached model.
|
||||
|
||||
Call this when stopping a filter, with the key returned from
|
||||
get_model(). When the last reference is released, the model
|
||||
is removed from the cache.
|
||||
|
||||
Args:
|
||||
key: Cache key returned by get_model().
|
||||
"""
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(key)
|
||||
|
||||
if entry is None:
|
||||
logger.warning(f"AIC model release unknown key={key!r}")
|
||||
return
|
||||
|
||||
model, ref_count = entry
|
||||
ref_count -= 1
|
||||
|
||||
if ref_count <= 0:
|
||||
del cls._cache[key]
|
||||
logger.debug(f"AIC model evicted key={key!r}")
|
||||
else:
|
||||
cls._cache[key] = (model, ref_count)
|
||||
logger.debug(f"AIC model key={key!r} ref_count={ref_count}")
|
||||
|
||||
|
||||
class AICFilter(BaseAudioFilter):
|
||||
"""Audio filter using ai-coustics' AIC SDK for real-time enhancement.
|
||||
|
||||
@@ -91,7 +265,8 @@ class AICFilter(BaseAudioFilter):
|
||||
32768.0 # 2^15, for normalizing int16 (-32768 to 32767) to float32 (-1.0 to 1.0)
|
||||
)
|
||||
|
||||
# AIC SDK objects
|
||||
# AIC SDK objects; model is shared via AICModelManager
|
||||
self._model_cache_key: Optional[str] = None
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._processor_ctx = None
|
||||
@@ -162,16 +337,12 @@ class AICFilter(BaseAudioFilter):
|
||||
"""
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
# Load or download model
|
||||
if self._model_path:
|
||||
logger.debug(f"Loading AIC model from: {self._model_path}")
|
||||
self._model = Model.from_file(str(self._model_path))
|
||||
else:
|
||||
logger.debug(f"Downloading AIC model: {self._model_id}")
|
||||
self._model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path}")
|
||||
self._model = Model.from_file(model_path)
|
||||
# Acquire shared read-only model from singleton manager
|
||||
self._model, self._model_cache_key = await AICModelManager.acquire(
|
||||
model_path=self._model_path,
|
||||
model_id=self._model_id,
|
||||
model_download_dir=self._model_download_dir,
|
||||
)
|
||||
|
||||
# Get optimal frames for this sample rate
|
||||
self._frames_per_block = self._model.get_optimal_num_frames(self._sample_rate)
|
||||
@@ -242,6 +413,10 @@ class AICFilter(BaseAudioFilter):
|
||||
self._aic_ready = False
|
||||
self._audio_buffer.clear()
|
||||
|
||||
if self._model_cache_key is not None:
|
||||
AICModelManager.release(self._model_cache_key)
|
||||
self._model_cache_key = None
|
||||
|
||||
async def process_frame(self, frame: FilterControlFrame):
|
||||
"""Process control frames to enable/disable filtering.
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from pipecat.utils.utils import obj_count, obj_id
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
|
||||
|
||||
class DeprecatedKeypadEntry:
|
||||
@@ -122,6 +123,9 @@ class Frame:
|
||||
id: Unique identifier for the frame instance.
|
||||
name: Human-readable name combining class name and instance count.
|
||||
pts: Presentation timestamp in nanoseconds.
|
||||
broadcast_sibling_id: ID of the paired frame when this frame was
|
||||
broadcast in both directions. Set automatically by
|
||||
``broadcast_frame()`` and ``broadcast_frame_instance()``.
|
||||
metadata: Dictionary for arbitrary frame metadata.
|
||||
transport_source: Name of the transport source that created this frame.
|
||||
transport_destination: Name of the transport destination for this frame.
|
||||
@@ -130,6 +134,7 @@ class Frame:
|
||||
id: int = field(init=False)
|
||||
name: str = field(init=False)
|
||||
pts: Optional[int] = field(init=False)
|
||||
broadcast_sibling_id: Optional[int] = field(init=False)
|
||||
metadata: Dict[str, Any] = field(init=False)
|
||||
transport_source: Optional[str] = field(init=False)
|
||||
transport_destination: Optional[str] = field(init=False)
|
||||
@@ -138,6 +143,7 @@ class Frame:
|
||||
self.id: int = obj_id()
|
||||
self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
self.pts: Optional[int] = None
|
||||
self.broadcast_sibling_id: Optional[int] = None
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.transport_source: Optional[str] = None
|
||||
self.transport_destination: Optional[str] = None
|
||||
@@ -1036,6 +1042,7 @@ class StartFrame(SystemFrame):
|
||||
Use `LLMUserAggregator`'s new `user_turn_strategies` parameter instead.
|
||||
|
||||
report_only_initial_ttfb: Whether to report only initial time-to-first-byte.
|
||||
tracing_context: Pipeline-scoped tracing context for span hierarchy.
|
||||
"""
|
||||
|
||||
audio_in_sample_rate: int = 16000
|
||||
@@ -1046,6 +1053,7 @@ class StartFrame(SystemFrame):
|
||||
enable_usage_metrics: bool = False
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
|
||||
report_only_initial_ttfb: bool = False
|
||||
tracing_context: Optional["TracingContext"] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -2142,6 +2150,20 @@ class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserIdleTimeoutUpdateFrame(SystemFrame):
|
||||
"""Frame for updating the user idle timeout at runtime.
|
||||
|
||||
Setting timeout to 0 disables idle detection. Setting a positive value
|
||||
enables it.
|
||||
|
||||
Parameters:
|
||||
timeout: The new idle timeout in seconds. 0 disables idle detection.
|
||||
"""
|
||||
|
||||
timeout: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class VADParamsUpdateFrame(ControlFrame):
|
||||
"""Frame for updating VAD parameters.
|
||||
|
||||
@@ -53,6 +53,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, F
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIObserverParams, RTVIProcessor
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager, TaskManager, TaskManagerParams
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
HEARTBEAT_SECS = 1.0
|
||||
@@ -290,10 +291,13 @@ class PipelineTask(BasePipelineTask):
|
||||
self._turn_tracking_observer: Optional[TurnTrackingObserver] = None
|
||||
self._user_bot_latency_observer: Optional[UserBotLatencyObserver] = None
|
||||
self._turn_trace_observer: Optional[TurnTraceObserver] = None
|
||||
self._tracing_context: Optional[TracingContext] = None
|
||||
if self._enable_turn_tracking:
|
||||
self._turn_tracking_observer = TurnTrackingObserver()
|
||||
observers.append(self._turn_tracking_observer)
|
||||
if self._enable_tracing and self._turn_tracking_observer:
|
||||
# Create pipeline-scoped tracing context
|
||||
self._tracing_context = TracingContext()
|
||||
# Create latency observer for tracing
|
||||
self._user_bot_latency_observer = UserBotLatencyObserver()
|
||||
observers.append(self._user_bot_latency_observer)
|
||||
@@ -303,6 +307,7 @@ class PipelineTask(BasePipelineTask):
|
||||
latency_tracker=self._user_bot_latency_observer,
|
||||
conversation_id=self._conversation_id,
|
||||
additional_span_attributes=self._additional_span_attributes,
|
||||
tracing_context=self._tracing_context,
|
||||
)
|
||||
observers.append(self._turn_trace_observer)
|
||||
|
||||
@@ -813,6 +818,7 @@ class PipelineTask(BasePipelineTask):
|
||||
enable_usage_metrics=self._params.enable_usage_metrics,
|
||||
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
|
||||
interruption_strategies=self._params.interruption_strategies,
|
||||
tracing_context=self._tracing_context,
|
||||
)
|
||||
start_frame.metadata = self._create_start_metadata()
|
||||
await self._pipeline.queue_frame(start_frame)
|
||||
|
||||
@@ -92,9 +92,9 @@ class LLMUserAggregatorParams:
|
||||
user_mute_strategies: List of user mute strategies.
|
||||
user_turn_stop_timeout: Time in seconds to wait before considering the
|
||||
user's turn finished.
|
||||
user_idle_timeout: Optional timeout in seconds for detecting user idle state.
|
||||
If set, the aggregator will emit an `on_user_turn_idle` event when the user
|
||||
has been idle (not speaking) for this duration. Set to None to disable
|
||||
user_idle_timeout: Timeout in seconds for detecting user idle state.
|
||||
The aggregator will emit an `on_user_turn_idle` event when the user
|
||||
has been idle (not speaking) for this duration. Set to 0 to disable
|
||||
idle detection.
|
||||
vad_analyzer: Voice Activity Detection analyzer instance.
|
||||
filter_incomplete_user_turns: Whether to filter out incomplete user turns.
|
||||
@@ -109,7 +109,7 @@ class LLMUserAggregatorParams:
|
||||
user_turn_strategies: Optional[UserTurnStrategies] = None
|
||||
user_mute_strategies: List[BaseUserMuteStrategy] = field(default_factory=list)
|
||||
user_turn_stop_timeout: float = 5.0
|
||||
user_idle_timeout: Optional[float] = None
|
||||
user_idle_timeout: float = 0
|
||||
vad_analyzer: Optional[VADAnalyzer] = None
|
||||
filter_incomplete_user_turns: bool = False
|
||||
user_turn_completion_config: Optional[UserTurnCompletionConfig] = None
|
||||
@@ -404,15 +404,10 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
"on_user_turn_stop_timeout", self._on_user_turn_stop_timeout
|
||||
)
|
||||
|
||||
# Optional user idle controller
|
||||
self._user_idle_controller: Optional[UserIdleController] = None
|
||||
if self._params.user_idle_timeout:
|
||||
self._user_idle_controller = UserIdleController(
|
||||
user_idle_timeout=self._params.user_idle_timeout
|
||||
)
|
||||
self._user_idle_controller.add_event_handler(
|
||||
"on_user_turn_idle", self._on_user_turn_idle
|
||||
)
|
||||
self._user_idle_controller = UserIdleController(
|
||||
user_idle_timeout=self._params.user_idle_timeout
|
||||
)
|
||||
self._user_idle_controller.add_event_handler("on_user_turn_idle", self._on_user_turn_idle)
|
||||
|
||||
# VAD controller
|
||||
self._vad_controller: Optional[VADController] = None
|
||||
@@ -489,8 +484,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
await self._user_turn_controller.process_frame(frame)
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.process_frame(frame)
|
||||
await self._user_idle_controller.process_frame(frame)
|
||||
|
||||
async def push_aggregation(self) -> str:
|
||||
"""Push the current aggregation."""
|
||||
@@ -507,8 +501,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
async def _start(self, frame: StartFrame):
|
||||
await self._user_turn_controller.setup(self.task_manager)
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.setup(self.task_manager)
|
||||
await self._user_idle_controller.setup(self.task_manager)
|
||||
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.setup(self.task_manager)
|
||||
@@ -541,14 +534,19 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
async def _cleanup(self):
|
||||
await self._user_turn_controller.cleanup()
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.cleanup()
|
||||
await self._user_idle_controller.cleanup()
|
||||
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.cleanup()
|
||||
|
||||
async def _maybe_mute_frame(self, frame: Frame):
|
||||
# Lifecycle frames should never be muted and should not trigger mute
|
||||
# state changes. Evaluating mute strategies on StartFrame would
|
||||
# broadcast UserMuteStartedFrame before StartFrame reaches downstream
|
||||
# processors.
|
||||
if isinstance(frame, (StartFrame, EndFrame, CancelFrame)):
|
||||
return False
|
||||
|
||||
should_mute_frame = self._user_is_muted and isinstance(
|
||||
frame,
|
||||
(
|
||||
@@ -689,6 +687,8 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
@@ -705,6 +705,8 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
await self._maybe_emit_user_turn_stopped(strategy)
|
||||
|
||||
async def _on_user_turn_stop_timeout(self, controller):
|
||||
@@ -1255,8 +1257,8 @@ class LLMContextAggregatorPair:
|
||||
self,
|
||||
context: LLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
user_params: Optional[LLMUserAggregatorParams] = None,
|
||||
assistant_params: Optional[LLMAssistantAggregatorParams] = None,
|
||||
):
|
||||
"""Initialize the LLM context aggregator pair.
|
||||
|
||||
@@ -1265,6 +1267,8 @@ class LLMContextAggregatorPair:
|
||||
user_params: Parameters for the user context aggregator.
|
||||
assistant_params: Parameters for the assistant context aggregator.
|
||||
"""
|
||||
user_params = user_params or LLMUserAggregatorParams()
|
||||
assistant_params = assistant_params or LLMAssistantAggregatorParams()
|
||||
self._user = LLMUserAggregator(context, params=user_params)
|
||||
self._assistant = LLMAssistantAggregator(context, params=assistant_params)
|
||||
|
||||
|
||||
@@ -52,8 +52,6 @@ from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMet
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
INTERRUPTION_COMPLETION_TIMEOUT = 2.0
|
||||
|
||||
|
||||
class FrameDirection(Enum):
|
||||
"""Direction of frame flow in the processing pipeline.
|
||||
@@ -419,27 +417,49 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
self._metrics.set_core_metrics_data(data)
|
||||
|
||||
async def start_ttfb_metrics(self):
|
||||
"""Start time-to-first-byte metrics collection."""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
await self._metrics.start_ttfb_metrics(self._report_only_initial_ttfb)
|
||||
async def start_ttfb_metrics(self, *, start_time: Optional[float] = None):
|
||||
"""Start time-to-first-byte metrics collection.
|
||||
|
||||
async def stop_ttfb_metrics(self):
|
||||
"""Stop time-to-first-byte metrics collection and push results."""
|
||||
Args:
|
||||
start_time: Optional timestamp to use as the start time. If None,
|
||||
uses the current time.
|
||||
"""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_ttfb_metrics()
|
||||
await self._metrics.start_ttfb_metrics(
|
||||
start_time=start_time, report_only_initial_ttfb=self._report_only_initial_ttfb
|
||||
)
|
||||
|
||||
async def stop_ttfb_metrics(self, *, end_time: Optional[float] = None):
|
||||
"""Stop time-to-first-byte metrics collection and push results.
|
||||
|
||||
Args:
|
||||
end_time: Optional timestamp to use as the end time. If None, uses
|
||||
the current time.
|
||||
"""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_ttfb_metrics(end_time=end_time)
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def start_processing_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
await self._metrics.start_processing_metrics()
|
||||
async def start_processing_metrics(self, *, start_time: Optional[float] = None):
|
||||
"""Start processing metrics collection.
|
||||
|
||||
async def stop_processing_metrics(self):
|
||||
"""Stop processing metrics collection and push results."""
|
||||
Args:
|
||||
start_time: Optional timestamp to use as the start time. If None,
|
||||
uses the current time.
|
||||
"""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_processing_metrics()
|
||||
await self._metrics.start_processing_metrics(start_time=start_time)
|
||||
|
||||
async def stop_processing_metrics(self, *, end_time: Optional[float] = None):
|
||||
"""Stop processing metrics collection and push results.
|
||||
|
||||
Args:
|
||||
end_time: Optional timestamp to use as the end time. If None, uses
|
||||
the current time.
|
||||
"""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_processing_metrics(end_time=end_time)
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -741,7 +761,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_push_frame", frame)
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
async def push_interruption_task_frame_and_wait(self, *, timeout: float = 5.0):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the
|
||||
@@ -750,9 +770,11 @@ class FrameProcessor(BaseObject):
|
||||
attached to both frames so the caller can wait until the interruption
|
||||
has fully traversed the pipeline. The event is set when the
|
||||
`InterruptionFrame` reaches the pipeline sink. If the frame does
|
||||
not complete within `INTERRUPTION_COMPLETION_TIMEOUT` seconds, a
|
||||
warning is logged periodically until it completes.
|
||||
not complete within the given timeout, a warning is logged and the
|
||||
event is forcibly set so the caller is unblocked.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for the interruption to complete.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
@@ -760,19 +782,20 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(event=event), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for the `InterruptionFrame` to complete and log a warning
|
||||
# periodically if it takes too long.
|
||||
# Wait for the `InterruptionFrame` to complete and log a warning if it
|
||||
# takes too long. If it does take too long make sure we unblock it,
|
||||
# otherwise we will hang here forever.
|
||||
while not event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=INTERRUPTION_COMPLETION_TIMEOUT)
|
||||
await asyncio.wait_for(event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{self}: InterruptionFrame has not completed after"
|
||||
f" {INTERRUPTION_COMPLETION_TIMEOUT}s. Make sure"
|
||||
" InterruptionFrame.complete() is being called (e.g. if the"
|
||||
" frame is being blocked or consumed before reaching the"
|
||||
" pipeline sink)."
|
||||
f" {timeout}s. Make sure InterruptionFrame.complete()"
|
||||
" is being called (e.g. if the frame is being blocked"
|
||||
" or consumed before reaching the pipeline sink)."
|
||||
)
|
||||
event.set()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
@@ -787,8 +810,12 @@ class FrameProcessor(BaseObject):
|
||||
frame_cls: The class of the frame to be broadcasted.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self.push_frame(frame_cls(**kwargs))
|
||||
await self.push_frame(frame_cls(**kwargs), FrameDirection.UPSTREAM)
|
||||
downstream_frame = frame_cls(**kwargs)
|
||||
upstream_frame = frame_cls(**kwargs)
|
||||
downstream_frame.broadcast_sibling_id = upstream_frame.id
|
||||
upstream_frame.broadcast_sibling_id = downstream_frame.id
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
async def broadcast_frame_instance(self, frame: Frame):
|
||||
"""Broadcasts a frame instance upstream and downstream.
|
||||
@@ -812,15 +839,18 @@ class FrameProcessor(BaseObject):
|
||||
if not f.init and f.name not in ("id", "name")
|
||||
}
|
||||
|
||||
new_frame = frame_cls(**init_fields)
|
||||
downstream_frame = frame_cls(**init_fields)
|
||||
for k, v in extra_fields.items():
|
||||
setattr(new_frame, k, v)
|
||||
await self.push_frame(new_frame)
|
||||
setattr(downstream_frame, k, v)
|
||||
|
||||
new_frame = frame_cls(**init_fields)
|
||||
upstream_frame = frame_cls(**init_fields)
|
||||
for k, v in extra_fields.items():
|
||||
setattr(new_frame, k, v)
|
||||
await self.push_frame(new_frame, FrameDirection.UPSTREAM)
|
||||
setattr(upstream_frame, k, v)
|
||||
|
||||
downstream_frame.broadcast_sibling_id = upstream_frame.id
|
||||
upstream_frame.broadcast_sibling_id = downstream_frame.id
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
"""Handle the start frame to initialize processor state.
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
@@ -1026,6 +1027,11 @@ class RTVIObserverParams:
|
||||
metrics_enabled: Indicates if metrics messages should be sent.
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
ignored_sources: List of frame processors whose frames should be silently ignored
|
||||
by this observer. Useful for suppressing RTVI messages from secondary pipeline
|
||||
branches (e.g. a silent evaluation LLM) that should not be visible to clients.
|
||||
Sources can also be added and removed dynamically via ``add_ignored_source()``
|
||||
and ``remove_ignored_source()``.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
@@ -1065,6 +1071,7 @@ class RTVIObserverParams:
|
||||
metrics_enabled: bool = True
|
||||
system_logs_enabled: bool = False
|
||||
errors_enabled: Optional[bool] = None
|
||||
ignored_sources: List[FrameProcessor] = field(default_factory=list)
|
||||
skip_aggregator_types: Optional[List[AggregationType | str]] = None
|
||||
bot_output_transforms: Optional[
|
||||
List[
|
||||
@@ -1110,12 +1117,17 @@ class RTVIObserver(BaseObserver):
|
||||
self._rtvi = rtvi
|
||||
self._params = params or RTVIObserverParams()
|
||||
|
||||
self._ignored_sources: Set[FrameProcessor] = set(self._params.ignored_sources)
|
||||
self._frames_seen = set()
|
||||
|
||||
self._bot_transcription = ""
|
||||
self._last_user_audio_level = 0
|
||||
self._last_bot_audio_level = 0
|
||||
|
||||
# Track bot speaking state for queuing aggregated text frames
|
||||
self._bot_is_speaking = False
|
||||
self._queued_aggregated_text_frames: List[AggregatedTextFrame] = []
|
||||
|
||||
if self._params.system_logs_enabled:
|
||||
self._system_logger_id = logger.add(self._logger_sink)
|
||||
|
||||
@@ -1166,6 +1178,31 @@ class RTVIObserver(BaseObserver):
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
def add_ignored_source(self, source: FrameProcessor):
|
||||
"""Ignore all frames pushed by the given processor.
|
||||
|
||||
Any frame whose source matches ``source`` will be silently skipped,
|
||||
preventing RTVI messages from being emitted for activity in that
|
||||
processor. Useful for suppressing events from secondary pipeline
|
||||
branches (e.g. a silent evaluation LLM) that should not be visible
|
||||
to clients.
|
||||
|
||||
Args:
|
||||
source: The frame processor to ignore.
|
||||
"""
|
||||
self._ignored_sources.add(source)
|
||||
|
||||
def remove_ignored_source(self, source: FrameProcessor):
|
||||
"""Stop ignoring frames pushed by the given processor.
|
||||
|
||||
Reverses a previous call to ``add_ignored_source()``. If ``source``
|
||||
was not previously ignored this is a no-op.
|
||||
|
||||
Args:
|
||||
source: The frame processor to stop ignoring.
|
||||
"""
|
||||
self._ignored_sources.discard(source)
|
||||
|
||||
def _get_function_call_report_level(self, function_name: str) -> RTVIFunctionCallReportLevel:
|
||||
"""Get the report level for a specific function call.
|
||||
|
||||
@@ -1216,10 +1253,13 @@ class RTVIObserver(BaseObserver):
|
||||
frame = data.frame
|
||||
direction = data.direction
|
||||
|
||||
# Only process downstream frames. Some frames are broadcast in both
|
||||
# directions (e.g. UserStartedSpeakingFrame, FunctionCallResultFrame),
|
||||
# and we only want to send one RTVI message per event.
|
||||
if direction != FrameDirection.DOWNSTREAM:
|
||||
# Frames from explicitly ignored sources are always skipped.
|
||||
if self._ignored_sources and src in self._ignored_sources:
|
||||
return
|
||||
|
||||
# For broadcast frames (pushed in both directions), only process
|
||||
# the downstream copy to avoid sending duplicate RTVI messages.
|
||||
if frame.broadcast_sibling_id is not None and direction != FrameDirection.DOWNSTREAM:
|
||||
return
|
||||
|
||||
# If we have already seen this frame, let's skip it.
|
||||
@@ -1384,17 +1424,30 @@ class RTVIObserver(BaseObserver):
|
||||
|
||||
async def _handle_bot_speaking(self, frame: Frame):
|
||||
"""Handle bot speaking event frames."""
|
||||
message = None
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
message = RTVIBotStartedSpeakingMessage()
|
||||
await self.send_rtvi_message(message)
|
||||
# Flush any queued aggregated text frames
|
||||
for queued_frame in self._queued_aggregated_text_frames:
|
||||
await self._send_aggregated_llm_text(queued_frame)
|
||||
self._queued_aggregated_text_frames.clear()
|
||||
self._bot_is_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
message = RTVIBotStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self.send_rtvi_message(message)
|
||||
self._bot_is_speaking = False
|
||||
|
||||
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
|
||||
"""Handle aggregated LLM text output frames."""
|
||||
if self._bot_is_speaking:
|
||||
# Bot has already started speaking, send directly
|
||||
await self._send_aggregated_llm_text(frame)
|
||||
else:
|
||||
# Bot hasn't started speaking yet, queue the frame
|
||||
self._queued_aggregated_text_frames.append(frame)
|
||||
|
||||
async def _send_aggregated_llm_text(self, frame: AggregatedTextFrame):
|
||||
"""Send aggregated LLM text messages."""
|
||||
# Skip certain aggregator types if configured to do so.
|
||||
if (
|
||||
self._params.skip_aggregator_types
|
||||
|
||||
@@ -107,49 +107,70 @@ class FrameProcessorMetrics(BaseObject):
|
||||
"""
|
||||
self._core_metrics_data = MetricsData(processor=name)
|
||||
|
||||
async def start_ttfb_metrics(self, report_only_initial_ttfb):
|
||||
async def start_ttfb_metrics(
|
||||
self, *, start_time: Optional[float] = None, report_only_initial_ttfb: bool
|
||||
):
|
||||
"""Start measuring time-to-first-byte (TTFB).
|
||||
|
||||
Args:
|
||||
start_time: Optional timestamp to use as the start time. If None,
|
||||
uses the current time.
|
||||
report_only_initial_ttfb: Whether to report only the first TTFB measurement.
|
||||
"""
|
||||
if self._should_report_ttfb:
|
||||
self._start_ttfb_time = time.time()
|
||||
self._start_ttfb_time = start_time or time.time()
|
||||
self._last_ttfb_time = 0
|
||||
self._should_report_ttfb = not report_only_initial_ttfb
|
||||
|
||||
async def stop_ttfb_metrics(self):
|
||||
async def stop_ttfb_metrics(self, *, end_time: Optional[float] = None):
|
||||
"""Stop TTFB measurement and generate metrics frame.
|
||||
|
||||
Args:
|
||||
end_time: Optional timestamp to use as the end time. If None, uses
|
||||
the current time.
|
||||
|
||||
Returns:
|
||||
MetricsFrame containing TTFB data, or None if not measuring.
|
||||
"""
|
||||
if self._start_ttfb_time == 0:
|
||||
return None
|
||||
|
||||
self._last_ttfb_time = time.time() - self._start_ttfb_time
|
||||
logger.debug(f"{self._processor_name()} TTFB: {self._last_ttfb_time}")
|
||||
end_time = end_time or time.time()
|
||||
|
||||
self._last_ttfb_time = end_time - self._start_ttfb_time
|
||||
logger.debug(f"{self._processor_name()} TTFB: {self._last_ttfb_time:.3f}s")
|
||||
ttfb = TTFBMetricsData(
|
||||
processor=self._processor_name(), value=self._last_ttfb_time, model=self._model_name()
|
||||
)
|
||||
self._start_ttfb_time = 0
|
||||
return MetricsFrame(data=[ttfb])
|
||||
|
||||
async def start_processing_metrics(self):
|
||||
"""Start measuring processing time."""
|
||||
self._start_processing_time = time.time()
|
||||
async def start_processing_metrics(self, *, start_time: Optional[float] = None):
|
||||
"""Start measuring processing time.
|
||||
|
||||
async def stop_processing_metrics(self):
|
||||
Args:
|
||||
start_time: Optional timestamp to use as the start time. If None,
|
||||
uses the current time.
|
||||
"""
|
||||
self._start_processing_time = start_time or time.time()
|
||||
|
||||
async def stop_processing_metrics(self, *, end_time: Optional[float] = None):
|
||||
"""Stop processing time measurement and generate metrics frame.
|
||||
|
||||
Args:
|
||||
end_time: Optional timestamp to use as the end time. If None, uses
|
||||
the current time.
|
||||
|
||||
Returns:
|
||||
MetricsFrame containing processing duration data, or None if not measuring.
|
||||
"""
|
||||
if self._start_processing_time == 0:
|
||||
return None
|
||||
|
||||
value = time.time() - self._start_processing_time
|
||||
logger.debug(f"{self._processor_name()} processing time: {value}")
|
||||
end_time = end_time or time.time()
|
||||
|
||||
value = end_time - self._start_processing_time
|
||||
logger.debug(f"{self._processor_name()} processing time: {value:.3f}s")
|
||||
processing = ProcessingMetricsData(
|
||||
processor=self._processor_name(), value=value, model=self._model_name()
|
||||
)
|
||||
|
||||
@@ -44,6 +44,8 @@ class AIService(FrameProcessor):
|
||||
self._model_name: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._session_properties: Dict[str, Any] = {}
|
||||
self._tracing_enabled: bool = False
|
||||
self._tracing_context = None
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
@@ -72,7 +74,8 @@ class AIService(FrameProcessor):
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
pass
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
self._tracing_context = frame.tracing_context
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the AI service.
|
||||
|
||||
@@ -184,7 +184,7 @@ class AnthropicLLMService(LLMService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "claude-sonnet-4-5-20250929",
|
||||
model: str = "claude-sonnet-4-6",
|
||||
params: Optional[InputParams] = None,
|
||||
client=None,
|
||||
retry_timeout_secs: Optional[float] = 5.0,
|
||||
@@ -195,7 +195,7 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key for authentication.
|
||||
model: Model name to use. Defaults to "claude-sonnet-4-5-20250929".
|
||||
model: Model name to use. Defaults to "claude-sonnet-4-6".
|
||||
params: Optional model parameters for inference.
|
||||
client: Optional custom Anthropic client instance.
|
||||
retry_timeout_secs: Request timeout in seconds for retry logic.
|
||||
|
||||
@@ -147,7 +147,6 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._context_id = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -254,7 +253,7 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Async")
|
||||
# Close all contexts and the socket
|
||||
if self._context_id:
|
||||
if self.has_active_audio_context():
|
||||
await self._websocket.send(json.dumps({"terminate": True}))
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from Async")
|
||||
@@ -262,7 +261,7 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
await self.remove_active_audio_context()
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
@@ -272,10 +271,11 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio."""
|
||||
if not self._context_id or not self._websocket:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text=" ", context_id=self._context_id, force=True)
|
||||
msg = self._build_msg(text=" ", context_id=context_id, force=True)
|
||||
await self._websocket.send(msg)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -303,11 +303,11 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
|
||||
# Check if this message belongs to the current context.
|
||||
if not self.audio_context_available(received_ctx_id):
|
||||
if self._context_id == received_ctx_id:
|
||||
if self.get_active_audio_context_id() == received_ctx_id:
|
||||
logger.debug(
|
||||
f"Received a delayed message, recreating the context: {self._context_id}"
|
||||
f"Received a delayed message, recreating the context: {received_ctx_id}"
|
||||
)
|
||||
await self.create_audio_context(self._context_id)
|
||||
await self.create_audio_context(received_ctx_id)
|
||||
else:
|
||||
# This can happen if a message is received _after_ we have closed a context
|
||||
# due to user interruption but _before_ the `isFinal` message for the context
|
||||
@@ -328,10 +328,11 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
if self._context_id:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
keepalive_message = {
|
||||
"transcript": " ",
|
||||
"context_id": self._context_id,
|
||||
"context_id": context_id,
|
||||
}
|
||||
logger.trace("Sending keepalive message")
|
||||
else:
|
||||
@@ -347,19 +348,16 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
if self._context_id and self._websocket:
|
||||
if context_id and self._websocket:
|
||||
try:
|
||||
await self._websocket.send(
|
||||
json.dumps(
|
||||
{"context_id": self._context_id, "close_context": True, "transcript": ""}
|
||||
)
|
||||
json.dumps({"context_id": context_id, "close_context": True, "transcript": ""})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
self._context_id = None
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -379,15 +377,13 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
if not self._context_id:
|
||||
self._context_id = context_id
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
msg = self._build_msg(text=text, force=True, context_id=context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
|
||||
@@ -306,7 +306,6 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._context_id = None
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
@@ -429,7 +428,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
msg = {
|
||||
"transcript": text,
|
||||
"continue": continue_transcript,
|
||||
"context_id": self._context_id,
|
||||
"context_id": self.get_active_audio_context_id(),
|
||||
"model_id": self.model_name,
|
||||
"voice": voice_config,
|
||||
"output_format": self._settings["output_format"],
|
||||
@@ -522,7 +521,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -532,21 +531,22 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
if self._context_id:
|
||||
cancel_msg = json.dumps({"context_id": self._context_id, "cancel": True})
|
||||
if context_id:
|
||||
cancel_msg = json.dumps({"context_id": context_id, "cancel": True})
|
||||
await self._get_websocket().send(cancel_msg)
|
||||
self._context_id = None
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
if not self._context_id or not self._websocket:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text="", continue_transcript=False)
|
||||
await self._websocket.send(msg)
|
||||
self._context_id = None
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
@@ -578,7 +578,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self.push_frame(TTSStoppedFrame(context_id=ctx_id))
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
self._context_id = None
|
||||
self.reset_active_audio_context()
|
||||
else:
|
||||
await self.push_error(error_msg=f"Error, unknown message type: {msg}")
|
||||
|
||||
@@ -607,11 +607,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self._context_id:
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
await self.create_audio_context(self._context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
|
||||
|
||||
@@ -73,6 +73,20 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
Provides real-time speech recognition using Deepgram's WebSocket API with Flux capabilities.
|
||||
Supports configurable models, VAD events, and various audio processing options
|
||||
including advanced turn detection and EagerEndOfTurn events for improved conversational AI performance.
|
||||
|
||||
Event handlers available (in addition to WebsocketSTTService events):
|
||||
|
||||
- on_speech_started(service): Deepgram detected start of speech
|
||||
- on_utterance_end(service): Deepgram detected end of utterance
|
||||
- on_end_of_turn(service): Deepgram detected end of turn (EOT)
|
||||
- on_eager_end_of_turn(service): Deepgram predicted end of turn (EagerEOT)
|
||||
- on_turn_resumed(service): User resumed speaking after EagerEOT
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_end_of_turn")
|
||||
async def on_end_of_turn(service):
|
||||
...
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
|
||||
@@ -50,6 +50,17 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
Provides real-time speech recognition using Deepgram's WebSocket API.
|
||||
Supports configurable models, languages, and various audio processing options.
|
||||
|
||||
Event handlers available (in addition to STTService events):
|
||||
|
||||
- on_speech_started(service): Deepgram detected start of speech
|
||||
- on_utterance_end(service): Deepgram detected end of utterance
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_speech_started")
|
||||
async def on_speech_started(service):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -368,7 +368,6 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
return
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
speech_final = parsed.get("speech_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
@@ -376,8 +375,12 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final and speech_final:
|
||||
# Final transcription
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = parsed.get("from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
@@ -435,10 +438,12 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# Send finalize message to Deepgram when user stops speaking
|
||||
# This tells Deepgram to flush any remaining audio and return final results
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
|
||||
315
src/pipecat/services/deepgram/tts_sagemaker.py
Normal file
315
src/pipecat/services/deepgram/tts_sagemaker.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat TTS service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time speech synthesis with support for interruptions and
|
||||
streaming audio output.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
class DeepgramSageMakerTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech synthesis using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
audio generation with support for interruptions via the Clear message.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- ``pipecat-ai[sagemaker]`` installed
|
||||
|
||||
Example::
|
||||
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name="my-deepgram-tts-endpoint",
|
||||
region="us-east-2",
|
||||
voice="aura-2-helena-en",
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
|
||||
deployed (e.g., "my-deepgram-tts-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
self._encoding = encoding
|
||||
self.set_voice(voice)
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram TTS query string, creates the BiDi client,
|
||||
starts the streaming session, and launches a background task for processing
|
||||
responses.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram TTS on SageMaker...")
|
||||
|
||||
query_string = (
|
||||
f"model={self._voice_id}&encoding={self._encoding}&sample_rate={self.sample_rate}"
|
||||
)
|
||||
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/speak",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._client.start_session()
|
||||
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
logger.debug("Connected to Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a Close message to Deepgram, cancels the response processing task,
|
||||
and closes the BiDi session. Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
|
||||
|
||||
try:
|
||||
await self._client.send_json({"type": "Close"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Close message: {e}")
|
||||
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram TTS on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream. Attempts to decode
|
||||
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
|
||||
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
|
||||
a TTSAudioRawFrame downstream.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
payload = result.value.bytes_
|
||||
|
||||
# Try to decode as JSON control message first
|
||||
try:
|
||||
response_data = payload.decode("utf-8")
|
||||
parsed = json.loads(response_data)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {parsed}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {parsed}")
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {parsed}")
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: "
|
||||
f"{parsed.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {parsed}")
|
||||
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Flush"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram TTS on SageMaker.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame, then None (audio comes asynchronously via
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -342,7 +342,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
# Context management for v1 multi API
|
||||
self._context_id = None
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
|
||||
@@ -410,18 +409,19 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
elif voice_settings_changed and self._context_id:
|
||||
elif voice_settings_changed and self.has_active_audio_context():
|
||||
# Voice settings can be updated by closing current context
|
||||
# so new one gets created with updated voice settings
|
||||
logger.debug(f"Voice settings changed, closing current context to apply changes")
|
||||
context_id = self.get_active_audio_context_id()
|
||||
try:
|
||||
if self._websocket:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
json.dumps({"context_id": context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._context_id = None
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the ElevenLabs TTS service.
|
||||
@@ -453,10 +453,11 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
if not self._context_id or not self._websocket:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = {"context_id": self._context_id, "flush": True}
|
||||
msg = {"context_id": context_id, "flush": True}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -469,7 +470,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)], self._context_id)
|
||||
await self.add_word_timestamps([("Reset", 0)], self.get_active_audio_context_id())
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
@@ -544,14 +545,14 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from ElevenLabs")
|
||||
# Close all contexts and the socket
|
||||
if self._context_id:
|
||||
if self.has_active_audio_context():
|
||||
await self._websocket.send(json.dumps({"close_socket": True}))
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -562,11 +563,12 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
if self._context_id and self._websocket:
|
||||
logger.trace(f"Closing context {self._context_id} due to interruption")
|
||||
if context_id and self._websocket:
|
||||
logger.trace(f"Closing context {context_id} due to interruption")
|
||||
try:
|
||||
# ElevenLabs requires that Pipecat manages the contexts and closes them
|
||||
# when they're not longer in use. Since an InterruptionFrame is pushed
|
||||
@@ -575,11 +577,10 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
# Note: We do not need to call remove_audio_context here, as the context is
|
||||
# automatically reset when super ()._handle_interruption is called.
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
json.dumps({"context_id": context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._context_id = None
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
@@ -599,11 +600,11 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
# Check if this message belongs to the current context.
|
||||
if not self.audio_context_available(received_ctx_id):
|
||||
if self._context_id == received_ctx_id:
|
||||
if self.get_active_audio_context_id() == received_ctx_id:
|
||||
logger.debug(
|
||||
f"Received a delayed message, recreating the context: {self._context_id}"
|
||||
f"Received a delayed message, recreating the context: {received_ctx_id}"
|
||||
)
|
||||
await self.create_audio_context(self._context_id)
|
||||
await self.create_audio_context(received_ctx_id)
|
||||
else:
|
||||
# This can happen if a message is received _after_ we have closed a context
|
||||
# due to user interruption but _before_ the `isFinal` message for the context
|
||||
@@ -656,13 +657,14 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
if self._context_id:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
# Send keepalive with context ID to keep the connection alive
|
||||
keepalive_message = {
|
||||
"text": "",
|
||||
"context_id": self._context_id,
|
||||
"context_id": context_id,
|
||||
}
|
||||
logger.trace(f"Sending keepalive for context {self._context_id}")
|
||||
logger.trace(f"Sending keepalive for context {context_id}")
|
||||
else:
|
||||
# It's possible to have a user interruption which clears the context
|
||||
# without generating a new TTS response. In this case, we'll just send
|
||||
@@ -676,8 +678,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _send_text(self, text: str):
|
||||
"""Send text to the WebSocket for synthesis."""
|
||||
if self._websocket and self._context_id:
|
||||
msg = {"text": text, "context_id": self._context_id}
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if self._websocket and context_id:
|
||||
msg = {"text": text, "context_id": context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
@traced_tts
|
||||
@@ -698,31 +701,27 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
# If a context ID does not exist, use the provided one.
|
||||
# If an ID exists, that means the Pipeline doesn't allow
|
||||
# user interruptions, so continue using the current ID.
|
||||
# When interruptions are allowed, user speech results in
|
||||
# an interruption, which resets the context ID.
|
||||
if not self._context_id:
|
||||
self._context_id = context_id
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
# Initialize context with voice settings and pronunciation dictionaries
|
||||
msg = {"text": " ", "context_id": self._context_id}
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
if self._pronunciation_dictionary_locators:
|
||||
msg["pronunciation_dictionary_locators"] = [
|
||||
locator.model_dump() for locator in self._pronunciation_dictionary_locators
|
||||
]
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
# Initialize context with voice settings and pronunciation dictionaries
|
||||
msg = {"text": " ", "context_id": context_id}
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
if self._pronunciation_dictionary_locators:
|
||||
msg["pronunciation_dictionary_locators"] = [
|
||||
locator.model_dump()
|
||||
for locator in self._pronunciation_dictionary_locators
|
||||
]
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
@@ -16,13 +16,14 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService
|
||||
from pipecat.services.tts_service import AudioContextWordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -37,7 +38,7 @@ except ModuleNotFoundError as e:
|
||||
SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class GradiumTTSService(InterruptibleWordTTSService):
|
||||
class GradiumTTSService(AudioContextWordTTSService):
|
||||
"""Text-to-Speech service using Gradium's websocket API."""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -71,9 +72,9 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
params: Additional configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
**kwargs,
|
||||
@@ -95,7 +96,6 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
|
||||
# State tracking
|
||||
self._receive_task = None
|
||||
self._current_context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -126,7 +126,11 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Gradium API."""
|
||||
return {"text": text, "type": "text"}
|
||||
msg = {"text": text, "type": "text"}
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
msg["client_req_id"] = context_id
|
||||
return msg
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish websocket connection.
|
||||
@@ -197,6 +201,7 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
"type": "setup",
|
||||
"output_format": "pcm",
|
||||
"voice_id": self._voice_id,
|
||||
"close_ws_on_eos": False,
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
@@ -223,6 +228,7 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -234,18 +240,35 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
msg = {"type": "end_of_stream", "client_req_id": context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
self.reset_active_audio_context()
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by resetting context state.
|
||||
|
||||
The parent AudioContextTTSService._handle_interruption() cancels the audio context
|
||||
task and creates a new one. We reset _context_id so the next run_tts() creates a
|
||||
fresh context. No websocket reconnection needed — audio from the old client_req_id
|
||||
will be silently dropped since the audio context no longer exists.
|
||||
|
||||
Args:
|
||||
frame: The interruption frame.
|
||||
direction: The direction of the frame.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
"""Process incoming websocket messages, demultiplexing by client_req_id."""
|
||||
# TODO(laurent): This should not be necessary as it should happen when
|
||||
# receiving the messages but this does not seem to always be the case
|
||||
# and that may lead to a busy polling loop.
|
||||
@@ -253,41 +276,35 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
raise ConnectionClosedOK(None, None)
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
ctx_id = msg.get("client_req_id")
|
||||
|
||||
if msg["type"] == "audio":
|
||||
# Process audio chunk
|
||||
if not ctx_id or not self.audio_context_available(ctx_id):
|
||||
continue
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self._current_context_id,
|
||||
context_id=ctx_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(ctx_id, frame)
|
||||
|
||||
elif msg["type"] == "text":
|
||||
if self._current_context_id:
|
||||
await self.add_word_timestamps(
|
||||
[(msg["text"], msg["start_s"])], self._current_context_id
|
||||
)
|
||||
if ctx_id and self.audio_context_available(ctx_id):
|
||||
await self.add_word_timestamps([(msg["text"], msg["start_s"])], ctx_id)
|
||||
|
||||
elif msg["type"] == "end_of_stream":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
if ctx_id and self.audio_context_available(ctx_id):
|
||||
await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id)
|
||||
await self.remove_audio_context(ctx_id)
|
||||
await self.stop_all_metrics()
|
||||
|
||||
elif msg["type"] == "error":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(TTSStoppedFrame(context_id=ctx_id))
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -300,16 +317,17 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
_state = self._websocket.state if self._websocket is not None else None
|
||||
logger.debug(f"{self}: Generating TTS [{text}] {_state}")
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
self._websocket = None
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
self._current_context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
|
||||
@@ -216,7 +216,7 @@ class SessionProperties(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
instructions: Optional[str] = None
|
||||
voice: Optional[GrokVoice] = "Ara"
|
||||
voice: Optional[GrokVoice | str] = "Ara"
|
||||
turn_detection: Optional[TurnDetection] = Field(
|
||||
default_factory=lambda: TurnDetection(type="server_vad")
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
@@ -65,10 +65,12 @@ class InworldHttpTTSService(WordTTSService):
|
||||
Parameters:
|
||||
temperature: Temperature for speech synthesis.
|
||||
speaking_rate: Speaking rate for speech synthesis.
|
||||
timestamp_transport_strategy: The strategy to use for timestamp transport.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
speaking_rate: Optional[float] = None
|
||||
timestamp_transport_strategy: Optional[Literal["ASYNC", "SYNC"]] = "ASYNC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -128,6 +130,8 @@ class InworldHttpTTSService(WordTTSService):
|
||||
self._settings["temperature"] = params.temperature
|
||||
if params.speaking_rate is not None:
|
||||
self._settings["audioConfig"]["speakingRate"] = params.speaking_rate
|
||||
if params.timestamp_transport_strategy is not None:
|
||||
self._settings["timestampTransportStrategy"] = params.timestamp_transport_strategy
|
||||
|
||||
self._cumulative_time = 0.0
|
||||
|
||||
@@ -240,6 +244,8 @@ class InworldHttpTTSService(WordTTSService):
|
||||
|
||||
# Use WORD timestamps for simplicity and correct spacing/capitalization
|
||||
payload["timestampType"] = self._timestamp_type
|
||||
if "timestampTransportStrategy" in self._settings:
|
||||
payload["timestampTransportStrategy"] = self._settings["timestampTransportStrategy"]
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
headers = {
|
||||
@@ -427,6 +433,7 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
flushing of buffered text to achieve minimal latency while
|
||||
maintaining high quality audio output. If None (default),
|
||||
automatically set based on aggregate_sentences.
|
||||
timestamp_transport_strategy: The strategy to use for timestamp transport.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
@@ -434,7 +441,8 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
apply_text_normalization: Optional[str] = None
|
||||
max_buffer_delay_ms: Optional[int] = None
|
||||
buffer_char_threshold: Optional[int] = None
|
||||
auto_mode: Optional[bool] = None
|
||||
auto_mode: Optional[bool] = True
|
||||
timestamp_transport_strategy: Optional[Literal["ASYNC", "SYNC"]] = "ASYNC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -494,6 +502,8 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
self._settings["audioConfig"]["speakingRate"] = params.speaking_rate
|
||||
if params.apply_text_normalization is not None:
|
||||
self._settings["applyTextNormalization"] = params.apply_text_normalization
|
||||
if params.timestamp_transport_strategy is not None:
|
||||
self._settings["timestampTransportStrategy"] = params.timestamp_transport_strategy
|
||||
|
||||
if params.auto_mode is not None:
|
||||
self._settings["autoMode"] = params.auto_mode
|
||||
@@ -507,7 +517,6 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._context_id = None
|
||||
|
||||
# Track cumulative time across generations for monotonic timestamps within a turn.
|
||||
# When auto_mode is enabled, the server controls generations and timestamps reset
|
||||
@@ -563,9 +572,10 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
keeping the context open for subsequent text. The context is only
|
||||
closed on interruption, disconnect, or end of session.
|
||||
"""
|
||||
if self._context_id and self._websocket:
|
||||
logger.trace(f"Flushing audio for context {self._context_id}")
|
||||
await self._send_flush(self._context_id)
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id and self._websocket:
|
||||
logger.trace(f"Flushing audio for context {context_id}")
|
||||
await self._send_flush(context_id)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
@@ -630,7 +640,7 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
frame: The interruption frame.
|
||||
direction: The direction of the interruption.
|
||||
"""
|
||||
old_context_id = self._context_id
|
||||
old_context_id = self.get_active_audio_context_id()
|
||||
logger.trace(f"{self}: Handling interruption, old context: {old_context_id}")
|
||||
|
||||
await super()._handle_interruption(frame, direction)
|
||||
@@ -642,7 +652,6 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
|
||||
self._context_id = None
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
logger.trace(f"{self}: Interruption handled, context reset to None")
|
||||
@@ -726,9 +735,10 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Inworld WebSocket TTS")
|
||||
if self._context_id:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
try:
|
||||
await self._send_close_context(self._context_id)
|
||||
await self._send_close_context(context_id)
|
||||
except Exception:
|
||||
pass
|
||||
await self._websocket.close()
|
||||
@@ -736,7 +746,7 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
@@ -762,7 +772,7 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
]
|
||||
logger.debug(
|
||||
f"{self}: Received message types={msg_types}, ctx_id={ctx_id}, "
|
||||
f"current_ctx={self._context_id}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}"
|
||||
f"current_ctx={self.get_active_audio_context_id()}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}"
|
||||
)
|
||||
|
||||
# Check for errors
|
||||
@@ -774,7 +784,9 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
# Handle "Context not found" error (code 5)
|
||||
# This can happen when a keepalive message is sent but no context is available.
|
||||
if error_code == 5 and "not found" in error_msg.lower():
|
||||
logger.debug(f"{self}: Context {ctx_id or self._context_id} not found.")
|
||||
logger.debug(
|
||||
f"{self}: Context {ctx_id or self.get_active_audio_context_id()} not found."
|
||||
)
|
||||
continue
|
||||
|
||||
# For other errors, push error frame
|
||||
@@ -789,11 +801,9 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
# If the context isn't available but matches our current context ID,
|
||||
# recreate it (handles race conditions during interruption recovery).
|
||||
if ctx_id and not self.audio_context_available(ctx_id):
|
||||
if self._context_id == ctx_id:
|
||||
logger.trace(
|
||||
f"{self}: Recreating audio context for current context: {self._context_id}"
|
||||
)
|
||||
await self.create_audio_context(self._context_id)
|
||||
if self.get_active_audio_context_id() == ctx_id:
|
||||
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
|
||||
await self.create_audio_context(ctx_id)
|
||||
else:
|
||||
# This is a message from an old/closed context - skip it
|
||||
logger.trace(f"{self}: Skipping message from unavailable context: {ctx_id}")
|
||||
@@ -815,12 +825,12 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
if ctx_id:
|
||||
await self.append_to_audio_context(ctx_id, frame)
|
||||
|
||||
# timestampInfo is inside audioChunk
|
||||
timestamp_info = audio_chunk.get("timestampInfo")
|
||||
if timestamp_info:
|
||||
word_times = self._calculate_word_times(timestamp_info)
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, ctx_id)
|
||||
# timestampInfo is inside audioChunk
|
||||
timestamp_info = audio_chunk.get("timestampInfo")
|
||||
if timestamp_info:
|
||||
word_times = self._calculate_word_times(timestamp_info)
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, ctx_id)
|
||||
|
||||
# Handle context created confirmation
|
||||
if "contextCreated" in result:
|
||||
@@ -839,8 +849,8 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
logger.trace(f"{self}: Context closed on server: {ctx_id}")
|
||||
await self.stop_ttfb_metrics()
|
||||
# Only reset if this is our current context
|
||||
if ctx_id == self._context_id:
|
||||
self._context_id = None
|
||||
if ctx_id == self.get_active_audio_context_id():
|
||||
self.reset_active_audio_context()
|
||||
if ctx_id and self.audio_context_available(ctx_id):
|
||||
await self.remove_audio_context(ctx_id)
|
||||
await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id)
|
||||
@@ -852,12 +862,13 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
if self._context_id:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
keepalive_message = {
|
||||
"send_text": {"text": ""},
|
||||
"contextId": self._context_id,
|
||||
"contextId": context_id,
|
||||
}
|
||||
logger.trace(f"Sending keepalive for context {self._context_id}")
|
||||
logger.trace(f"Sending keepalive for context {context_id}")
|
||||
else:
|
||||
keepalive_message = {"send_text": {"text": ""}}
|
||||
logger.trace("Sending keepalive without context")
|
||||
@@ -884,6 +895,10 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
create_config["applyTextNormalization"] = self._settings["applyTextNormalization"]
|
||||
if "autoMode" in self._settings:
|
||||
create_config["autoMode"] = self._settings["autoMode"]
|
||||
if "timestampTransportStrategy" in self._settings:
|
||||
create_config["timestampTransportStrategy"] = self._settings[
|
||||
"timestampTransportStrategy"
|
||||
]
|
||||
|
||||
# Set buffer settings for timely audio generation.
|
||||
# Use provided values or defaults that work well for streaming LLM output.
|
||||
@@ -942,20 +957,13 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
await self._send_context(context_id)
|
||||
|
||||
if not self._context_id:
|
||||
self._context_id = context_id
|
||||
logger.trace(f"{self}: Creating new context {self._context_id}")
|
||||
await self.create_audio_context(self._context_id)
|
||||
await self._send_context(self._context_id)
|
||||
elif not self.audio_context_available(self._context_id):
|
||||
# Context exists on server but local tracking was removed
|
||||
logger.trace(f"{self}: Recreating local audio context {self._context_id}")
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
await self._send_text(self._context_id, text)
|
||||
await self._send_text(context_id, text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -198,7 +198,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: Optional[bool] = None
|
||||
self._summary_task: Optional[asyncio.Task] = None
|
||||
|
||||
@@ -285,7 +284,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
await super().start(frame)
|
||||
if not self._run_in_parallel:
|
||||
await self._create_sequential_runner_task()
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the LLM service.
|
||||
|
||||
@@ -375,20 +375,29 @@ class BaseOpenAILLMService(LLMService):
|
||||
else self._stream_chat_completions_universal_context(context)
|
||||
)
|
||||
|
||||
# Ensure stream is closed on cancellation/exception to prevent socket
|
||||
# leaks. OpenAI's AsyncStream uses close(), async generators use aclose().
|
||||
# Ensure stream and its async iterator are closed on cancellation/exception
|
||||
# to prevent socket leaks and uvloop crashes. Closing the iterator first
|
||||
# cascades cleanup through nested async generators (httpx/httpcore internals),
|
||||
# preventing uvloop's broken asyncgen finalizer from firing on Python 3.12+
|
||||
# (MagicStack/uvloop#699).
|
||||
@asynccontextmanager
|
||||
async def _closing(stream):
|
||||
chunk_iter = stream.__aiter__()
|
||||
try:
|
||||
yield stream
|
||||
yield chunk_iter
|
||||
finally:
|
||||
if hasattr(stream, "aclose"):
|
||||
await stream.aclose()
|
||||
elif hasattr(stream, "close"):
|
||||
# Close the iterator first to cascade cleanup through
|
||||
# nested async generators (httpx/httpcore internals).
|
||||
if hasattr(chunk_iter, "aclose"):
|
||||
await chunk_iter.aclose()
|
||||
# Then close the stream to release HTTP resources.
|
||||
if hasattr(stream, "close"):
|
||||
await stream.close()
|
||||
elif hasattr(stream, "aclose"):
|
||||
await stream.aclose()
|
||||
|
||||
async with _closing(chunk_stream):
|
||||
async for chunk in chunk_stream:
|
||||
async with _closing(chunk_stream) as chunk_iter:
|
||||
async for chunk in chunk_iter:
|
||||
if chunk.usage:
|
||||
cached_tokens = (
|
||||
chunk.usage.prompt_tokens_details.cached_tokens
|
||||
|
||||
@@ -13,6 +13,7 @@ supporting both WebSocket streaming and HTTP-based synthesis.
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
@@ -323,6 +324,20 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
def create_context_id(self) -> str:
|
||||
"""Generate a unique context ID for a TTS request in case we don't have one already in progress.
|
||||
|
||||
Returns:
|
||||
A unique string identifier for the TTS context.
|
||||
"""
|
||||
# If a context ID does not exist, create a new one.
|
||||
# If an ID exists, continue using the current ID.
|
||||
# When interruptions happen, user speech results in
|
||||
# an interruption, which resets the context ID.
|
||||
if not self._context_id:
|
||||
return str(uuid.uuid4())
|
||||
return self._context_id
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by stopping metrics and clearing request ID."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
@@ -25,8 +25,6 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextWordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -70,6 +68,7 @@ class ResembleAITTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
reuse_context_id_within_turn=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,25 +81,39 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
Parameters:
|
||||
language: Language for synthesis. Defaults to English.
|
||||
speed_alpha: Speech speed multiplier. Defaults to 1.0.
|
||||
reduce_latency: Whether to reduce latency at potential quality cost.
|
||||
pause_between_brackets: Whether to add pauses between bracketed content.
|
||||
phonemize_between_brackets: Whether to phonemize bracketed content.
|
||||
segment: Text segmentation mode ("immediate", "bySentence", "never").
|
||||
repetition_penalty: Token repetition penalty (arcana only).
|
||||
temperature: Sampling temperature (arcana only).
|
||||
top_p: Cumulative probability threshold (arcana only).
|
||||
speed_alpha: Speech speed multiplier (mistv2 only).
|
||||
reduce_latency: Whether to reduce latency at potential quality cost (mistv2 only).
|
||||
pause_between_brackets: Whether to add pauses between bracketed content (mistv2 only).
|
||||
phonemize_between_brackets: Whether to phonemize bracketed content (mistv2 only).
|
||||
no_text_normalization: Whether to disable text normalization (mistv2 only).
|
||||
save_oovs: Whether to save out-of-vocabulary words (mistv2 only).
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
speed_alpha: Optional[float] = 1.0
|
||||
reduce_latency: Optional[bool] = False
|
||||
pause_between_brackets: Optional[bool] = False
|
||||
phonemize_between_brackets: Optional[bool] = False
|
||||
segment: Optional[str] = None
|
||||
# Arcana params
|
||||
repetition_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
# Mistv2 params
|
||||
speed_alpha: Optional[float] = None
|
||||
reduce_latency: Optional[bool] = None
|
||||
pause_between_brackets: Optional[bool] = None
|
||||
phonemize_between_brackets: Optional[bool] = None
|
||||
no_text_normalization: Optional[bool] = None
|
||||
save_oovs: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
url: str = "wss://users.rime.ai/ws2",
|
||||
model: str = "mistv2",
|
||||
url: str = "wss://users-ws.rime.ai/ws3",
|
||||
model: str = "arcana",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
@@ -142,29 +156,16 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
|
||||
params = params or RimeTTSService.InputParams()
|
||||
self._params = params or RimeTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._model = model
|
||||
self._settings = {
|
||||
"speaker": voice_id,
|
||||
"modelId": model,
|
||||
"audioFormat": "pcm",
|
||||
"samplingRate": 0,
|
||||
"lang": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "eng",
|
||||
"speedAlpha": params.speed_alpha,
|
||||
"reduceLatency": params.reduce_latency,
|
||||
"pauseBetweenBrackets": json.dumps(params.pause_between_brackets),
|
||||
"phonemizeBetweenBrackets": json.dumps(params.phonemize_between_brackets),
|
||||
}
|
||||
self._settings = self._build_settings()
|
||||
|
||||
# State tracking
|
||||
self._context_id = None # Tracks current turn
|
||||
self._receive_task = None
|
||||
self._cumulative_time = 0 # Accumulates time across messages
|
||||
self._extra_msg_fields = {} # Extra fields for next message
|
||||
@@ -188,14 +189,60 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
return language_to_rime_language(language)
|
||||
|
||||
def _build_settings(self) -> dict:
|
||||
"""Build query params for the WebSocket URL based on the current model and params.
|
||||
|
||||
Returns:
|
||||
Dictionary of query parameters. Only explicitly-set values are included.
|
||||
"""
|
||||
settings = {
|
||||
"speaker": self._voice_id,
|
||||
"modelId": self._model,
|
||||
"audioFormat": "pcm",
|
||||
"samplingRate": self.sample_rate or 0,
|
||||
}
|
||||
if self._params.language:
|
||||
settings["lang"] = self.language_to_service_language(self._params.language) or "eng"
|
||||
if self._params.segment is not None:
|
||||
settings["segment"] = self._params.segment
|
||||
|
||||
if self._model == "arcana":
|
||||
if self._params.repetition_penalty is not None:
|
||||
settings["repetition_penalty"] = self._params.repetition_penalty
|
||||
if self._params.temperature is not None:
|
||||
settings["temperature"] = self._params.temperature
|
||||
if self._params.top_p is not None:
|
||||
settings["top_p"] = self._params.top_p
|
||||
else: # mistv2/mist
|
||||
if self._params.speed_alpha is not None:
|
||||
settings["speedAlpha"] = self._params.speed_alpha
|
||||
if self._params.reduce_latency is not None:
|
||||
settings["reduceLatency"] = self._params.reduce_latency
|
||||
if self._params.pause_between_brackets is not None:
|
||||
settings["pauseBetweenBrackets"] = json.dumps(self._params.pause_between_brackets)
|
||||
if self._params.phonemize_between_brackets is not None:
|
||||
settings["phonemizeBetweenBrackets"] = json.dumps(
|
||||
self._params.phonemize_between_brackets
|
||||
)
|
||||
if self._params.no_text_normalization is not None:
|
||||
settings["noTextNormalization"] = json.dumps(self._params.no_text_normalization)
|
||||
if self._params.save_oovs is not None:
|
||||
settings["saveOovs"] = json.dumps(self._params.save_oovs)
|
||||
|
||||
return settings
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Update the TTS model.
|
||||
"""Update the TTS model and reconnect.
|
||||
|
||||
Args:
|
||||
model: The model name to use for synthesis.
|
||||
"""
|
||||
self._model = model
|
||||
self._settings = self._build_settings()
|
||||
await super().set_model(model)
|
||||
if self._websocket:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
# A set of Rime-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
@@ -223,18 +270,74 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
return f"[{text}]"
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
"""Update service settings and reconnect if necessary.
|
||||
|
||||
Since all settings are WebSocket URL query parameters,
|
||||
any setting change requires reconnecting to apply the new values.
|
||||
"""
|
||||
prev_settings = self._settings.copy()
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
|
||||
needs_reconnect = False
|
||||
|
||||
if "voice" in settings or "voice_id" in settings:
|
||||
self._settings["speaker"] = self._voice_id
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
if prev_settings.get("speaker") != self._voice_id:
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
needs_reconnect = True
|
||||
|
||||
if "model" in settings:
|
||||
self._settings = self._build_settings()
|
||||
needs_reconnect = True
|
||||
|
||||
if "language" in settings:
|
||||
new_lang = self.language_to_service_language(settings["language"])
|
||||
if new_lang and new_lang != prev_settings.get("lang"):
|
||||
logger.info(f"Updating language to: [{new_lang}]")
|
||||
self._settings["lang"] = new_lang
|
||||
needs_reconnect = True
|
||||
|
||||
# Arcana params
|
||||
for key, settings_key in [
|
||||
("repetition_penalty", "repetition_penalty"),
|
||||
("temperature", "temperature"),
|
||||
("top_p", "top_p"),
|
||||
]:
|
||||
if key in settings and settings[key] != prev_settings.get(settings_key):
|
||||
self._settings[settings_key] = settings[key]
|
||||
needs_reconnect = True
|
||||
|
||||
# Mistv2 params
|
||||
for key, settings_key in [
|
||||
("speed_alpha", "speedAlpha"),
|
||||
("reduce_latency", "reduceLatency"),
|
||||
]:
|
||||
if key in settings and settings[key] != prev_settings.get(settings_key):
|
||||
self._settings[settings_key] = settings[key]
|
||||
needs_reconnect = True
|
||||
|
||||
# Mistv2 boolean params (need json.dumps)
|
||||
for key, settings_key in [
|
||||
("pause_between_brackets", "pauseBetweenBrackets"),
|
||||
("phonemize_between_brackets", "phonemizeBetweenBrackets"),
|
||||
("no_text_normalization", "noTextNormalization"),
|
||||
("save_oovs", "saveOovs"),
|
||||
]:
|
||||
if key in settings and json.dumps(settings[key]) != prev_settings.get(settings_key):
|
||||
self._settings[settings_key] = json.dumps(settings[key])
|
||||
needs_reconnect = True
|
||||
|
||||
if "segment" in settings and settings["segment"] != prev_settings.get("segment"):
|
||||
self._settings["segment"] = settings["segment"]
|
||||
needs_reconnect = True
|
||||
|
||||
if needs_reconnect and self._websocket:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Rime API."""
|
||||
msg = {"text": text, "contextId": self._context_id}
|
||||
msg = {"text": text, "contextId": self.get_active_audio_context_id()}
|
||||
if self._extra_msg_fields:
|
||||
msg |= self._extra_msg_fields
|
||||
self._extra_msg_fields = {}
|
||||
@@ -255,7 +358,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["samplingRate"] = self.sample_rate
|
||||
self._settings = self._build_settings()
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -301,7 +404,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
params = "&".join(f"{k}={v}" for k, v in self._settings.items())
|
||||
params = "&".join(f"{k}={v}" for k, v in self._settings.items() if v is not None)
|
||||
url = f"{self._url}?{params}"
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
@@ -322,7 +425,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -334,11 +437,11 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by clearing current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
if self._context_id:
|
||||
if context_id:
|
||||
await self._get_websocket().send(json.dumps(self._build_clear_msg()))
|
||||
self._context_id = None
|
||||
|
||||
def _calculate_word_times(self, words: list, starts: list, ends: list) -> list:
|
||||
"""Calculate word timing pairs with proper spacing and punctuation.
|
||||
@@ -371,19 +474,20 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._context_id or not self._websocket:
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
return
|
||||
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
await self._get_websocket().send(json.dumps({"operation": "flush"}))
|
||||
self._context_id = None
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
if not msg or not self.audio_context_available(msg["contextId"]):
|
||||
if not msg or not self.audio_context_available(msg.get("contextId")):
|
||||
continue
|
||||
|
||||
context_id = msg["contextId"]
|
||||
@@ -418,7 +522,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
self._context_id = None
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions.
|
||||
@@ -449,12 +553,11 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self._context_id:
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
self._context_id = context_id
|
||||
await self.create_audio_context(self._context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
@@ -626,20 +729,18 @@ class RimeHttpTTSService(TTSService):
|
||||
class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
"""Pipecat TTS service for Rime's non-JSON WebSocket API.
|
||||
|
||||
.. deprecated:: 0.0.102
|
||||
Arcana now supports JSON WebSocket with word-level timestamps via the
|
||||
``wss://users-ws.rime.ai/ws3`` endpoint. Use :class:`RimeTTSService`
|
||||
with ``model="arcana"`` instead.
|
||||
|
||||
This service enables Text-to-Speech synthesis over WebSocket endpoints
|
||||
that require plain text (not JSON) messages and return raw audio bytes.
|
||||
It is designed for use with TTS models like Arcana, which currently do
|
||||
not support JSON-based WebSocket protocols (though this may change in
|
||||
the future).
|
||||
|
||||
Limitations:
|
||||
- Does not support word-level timestamps or context IDs.
|
||||
- Intended specifically for integrations where the TTS provider only
|
||||
accepts and returns non-JSON messages.
|
||||
|
||||
Note:
|
||||
- Arcana and similar models may add JSON WebSocket support in the
|
||||
future. This service focuses on the current plain text protocol.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
|
||||
@@ -119,10 +119,10 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
use_translate_method=True,
|
||||
),
|
||||
"saaras:v3": ModelConfig(
|
||||
supports_prompt=True,
|
||||
supports_prompt=False,
|
||||
supports_mode=True,
|
||||
supports_language=True,
|
||||
default_language="en-IN",
|
||||
default_language="unknown",
|
||||
default_mode="transcribe",
|
||||
use_translate_endpoint=False,
|
||||
use_translate_method=False,
|
||||
@@ -134,6 +134,18 @@ class SarvamSTTService(STTService):
|
||||
"""Sarvam speech-to-text service.
|
||||
|
||||
Provides real-time speech recognition using Sarvam's WebSocket API.
|
||||
|
||||
Event handlers available (in addition to STTService events):
|
||||
|
||||
- on_connected(service): Connected to Sarvam WebSocket
|
||||
- on_disconnected(service): Disconnected from Sarvam WebSocket
|
||||
- on_connection_error(service, error): Connection error occurred
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_connected")
|
||||
async def on_connected(service):
|
||||
...
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -143,9 +155,9 @@ class SarvamSTTService(STTService):
|
||||
language: Target language for transcription.
|
||||
- saarika:v2.5: Defaults to "unknown" (auto-detect supported)
|
||||
- saaras:v2.5: Not used (auto-detects language)
|
||||
- saaras:v3: Defaults to "en-IN"
|
||||
- saaras:v3: Defaults to "unknown" (auto-detect supported)
|
||||
prompt: Optional prompt to guide transcription/translation style/context.
|
||||
Only applicable to saaras models (v2.5 and v3). Defaults to None.
|
||||
Only applicable to saaras:v2.5. Defaults to None.
|
||||
mode: Mode of operation for saaras:v3 models only. Options: transcribe, translate,
|
||||
verbatim, translit, codemix. Defaults to "transcribe" for saaras:v3.
|
||||
vad_signals: Enable VAD signals in response. Defaults to None.
|
||||
@@ -167,6 +179,8 @@ class SarvamSTTService(STTService):
|
||||
input_audio_codec: str = "wav",
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = SARVAM_TTFS_P99,
|
||||
keepalive_timeout: Optional[float] = None,
|
||||
keepalive_interval: float = 5.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Sarvam STT service.
|
||||
@@ -176,12 +190,15 @@ class SarvamSTTService(STTService):
|
||||
model: Sarvam model to use for transcription. Allowed values:
|
||||
- "saarika:v2.5": Standard STT model
|
||||
- "saaras:v2.5": STT-Translate model (auto-detects language, supports prompts)
|
||||
- "saaras:v3": Advanced STT model (supports mode and prompts)
|
||||
- "saaras:v3": Advanced STT model (supports mode)
|
||||
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
|
||||
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
|
||||
params: Configuration parameters for Sarvam STT service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
keepalive_timeout: Seconds of no audio before sending silence to keep the
|
||||
connection alive. None disables keepalive.
|
||||
keepalive_interval: Seconds between idle checks when keepalive is enabled.
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
params = params or SarvamSTTService.InputParams()
|
||||
@@ -203,7 +220,13 @@ class SarvamSTTService(STTService):
|
||||
f"Model '{model}' does not support language parameter (auto-detects language)."
|
||||
)
|
||||
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
keepalive_timeout=keepalive_timeout,
|
||||
keepalive_interval=keepalive_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
@@ -424,13 +447,32 @@ class SarvamSTTService(STTService):
|
||||
if self._config.supports_mode and self._mode is not None:
|
||||
connect_kwargs["mode"] = self._mode
|
||||
|
||||
# Prompt support differs across sarvamai versions. Prefer connect-time prompt
|
||||
# when available and gracefully degrade if the SDK doesn't accept it.
|
||||
if self._prompt is not None and self._config.supports_prompt:
|
||||
connect_kwargs["prompt"] = self._prompt
|
||||
|
||||
def _connect_with_sdk_headers(connect_fn, **kwargs):
|
||||
# Different SDK versions may use different kwarg names.
|
||||
for header_kw in ("headers", "additional_headers", "extra_headers"):
|
||||
# If prompt is unsupported at connect-time, retry without it.
|
||||
attempts = [kwargs]
|
||||
if "prompt" in kwargs:
|
||||
attempts.append({k: v for k, v in kwargs.items() if k != "prompt"})
|
||||
|
||||
last_type_error = None
|
||||
for attempt_kwargs in attempts:
|
||||
for header_kw in ("headers", "additional_headers", "extra_headers"):
|
||||
try:
|
||||
return connect_fn(**attempt_kwargs, **{header_kw: self._sdk_headers})
|
||||
except TypeError as e:
|
||||
last_type_error = e
|
||||
try:
|
||||
return connect_fn(**kwargs, **{header_kw: self._sdk_headers})
|
||||
except TypeError:
|
||||
pass
|
||||
return connect_fn(**attempt_kwargs)
|
||||
except TypeError as e:
|
||||
last_type_error = e
|
||||
|
||||
if last_type_error is not None:
|
||||
raise last_type_error
|
||||
return connect_fn(**kwargs)
|
||||
|
||||
# Choose the appropriate endpoint based on model configuration
|
||||
@@ -448,9 +490,11 @@ class SarvamSTTService(STTService):
|
||||
# Enter the async context manager
|
||||
self._socket_client = await self._websocket_context.__aenter__()
|
||||
|
||||
# Set prompt if provided (only for models that support prompts)
|
||||
# Fallback for SDKs that support runtime prompt updates.
|
||||
if self._prompt is not None and self._config.supports_prompt:
|
||||
await self._socket_client.set_prompt(self._prompt)
|
||||
prompt_setter = getattr(self._socket_client, "set_prompt", None)
|
||||
if callable(prompt_setter):
|
||||
await prompt_setter(self._prompt)
|
||||
|
||||
# Register event handler for incoming messages
|
||||
def _message_handler(message):
|
||||
@@ -463,6 +507,8 @@ class SarvamSTTService(STTService):
|
||||
# Start receive task using Pipecat's task management
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
self._create_keepalive_task()
|
||||
|
||||
logger.info("Connected to Sarvam successfully")
|
||||
|
||||
except ApiError as e:
|
||||
@@ -476,6 +522,8 @@ class SarvamSTTService(STTService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Sarvam WebSocket API using SDK."""
|
||||
await self._cancel_keepalive_task()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
@@ -600,6 +648,32 @@ class SarvamSTTService(STTService):
|
||||
}
|
||||
return mapping.get(language_code, Language.HI_IN)
|
||||
|
||||
def _is_keepalive_ready(self) -> bool:
|
||||
"""Check if the Sarvam SDK websocket client is connected."""
|
||||
return self._socket_client is not None
|
||||
|
||||
async def _send_keepalive(self, silence: bytes):
|
||||
"""Send silent audio via the Sarvam SDK to keep the connection alive.
|
||||
|
||||
Args:
|
||||
silence: Silent 16-bit mono PCM audio bytes.
|
||||
"""
|
||||
audio_base64 = base64.b64encode(silence).decode("utf-8")
|
||||
encoding = (
|
||||
self._input_audio_codec
|
||||
if self._input_audio_codec.startswith("audio/")
|
||||
else f"audio/{self._input_audio_codec}"
|
||||
)
|
||||
method_kwargs = {
|
||||
"audio": audio_base64,
|
||||
"encoding": encoding,
|
||||
"sample_rate": self.sample_rate,
|
||||
}
|
||||
if self._config.use_translate_method:
|
||||
await self._socket_client.translate(**method_kwargs)
|
||||
else:
|
||||
await self._socket_client.transcribe(**method_kwargs)
|
||||
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -131,7 +131,6 @@ class SimliVideoService(FrameProcessor):
|
||||
# Build SimliConfig from new parameters
|
||||
# Only pass optional parameters if explicitly provided to use SimliConfig defaults
|
||||
config_kwargs = {
|
||||
"apiKey": api_key,
|
||||
"faceId": face_id,
|
||||
}
|
||||
if params.max_session_length is not None:
|
||||
@@ -153,10 +152,10 @@ class SimliVideoService(FrameProcessor):
|
||||
config.maxIdleTime += 5
|
||||
config.maxSessionLength += 5
|
||||
self._simli_client = SimliClient(
|
||||
api_key=api_key,
|
||||
config=config,
|
||||
latencyInterval=latency_interval,
|
||||
simliURL=simli_url,
|
||||
enable_logging=params.enable_logging or False,
|
||||
enableSFU=True,
|
||||
)
|
||||
|
||||
self._pipecat_resampler: AudioResampler = None
|
||||
@@ -173,7 +172,7 @@ class SimliVideoService(FrameProcessor):
|
||||
"""Start the connection to Simli service and begin processing tasks."""
|
||||
try:
|
||||
if not self._initialized:
|
||||
await self._simli_client.Initialize()
|
||||
await self._simli_client.start()
|
||||
self._initialized = True
|
||||
|
||||
# Create task to consume and process audio and video
|
||||
|
||||
@@ -86,6 +86,16 @@ class SpeechmaticsSTTService(STTService):
|
||||
This service provides real-time speech-to-text transcription using the Speechmatics API.
|
||||
It supports partial and final transcriptions, multiple languages, various audio formats,
|
||||
and speaker diarization.
|
||||
|
||||
Event handlers available (in addition to STTService events):
|
||||
|
||||
- on_speakers_result(service, speakers): Speaker diarization results received
|
||||
|
||||
Example::
|
||||
|
||||
@stt.event_handler("on_speakers_result")
|
||||
async def on_speakers_result(service, speakers):
|
||||
...
|
||||
"""
|
||||
|
||||
# Export related classes as class attributes
|
||||
|
||||
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
MetricsFrame,
|
||||
ServiceSwitcherRequestMetadataFrame,
|
||||
StartFrame,
|
||||
STTMetadataFrame,
|
||||
@@ -31,7 +30,6 @@ from pipecat.frames.frames import (
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.stt_latency import DEFAULT_TTFS_P99
|
||||
@@ -49,6 +47,12 @@ class STTService(AIService):
|
||||
muting, settings management, and audio processing. Subclasses must implement
|
||||
the run_stt method to provide actual speech recognition.
|
||||
|
||||
Includes an optional keepalive mechanism that sends silent audio when no real
|
||||
audio has been sent for a configurable timeout, preventing servers from closing
|
||||
idle connections (e.g. when behind a ServiceSwitcher). Subclasses that enable
|
||||
keepalive must override ``_send_keepalive()`` to deliver the silence in the
|
||||
appropriate service-specific protocol.
|
||||
|
||||
Event handlers:
|
||||
on_connected: Called when connected to the STT service.
|
||||
on_disconnected: Called when disconnected from the STT service.
|
||||
@@ -76,6 +80,8 @@ class STTService(AIService):
|
||||
sample_rate: Optional[int] = None,
|
||||
stt_ttfb_timeout: float = 2.0,
|
||||
ttfs_p99_latency: Optional[float] = None,
|
||||
keepalive_timeout: Optional[float] = None,
|
||||
keepalive_interval: float = 5.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the STT service.
|
||||
@@ -95,6 +101,10 @@ class STTService(AIService):
|
||||
This is broadcast via STTMetadataFrame at pipeline start for downstream
|
||||
processors (e.g., turn strategies) to optimize timing. Subclasses provide
|
||||
measured defaults; pass a value here to override for your deployment.
|
||||
keepalive_timeout: Seconds of no audio before sending silence to keep the
|
||||
connection alive. None disables keepalive. Useful for services that
|
||||
close idle connections (e.g. behind a ServiceSwitcher).
|
||||
keepalive_interval: Seconds between idle checks when keepalive is enabled.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -102,7 +112,6 @@ class STTService(AIService):
|
||||
self._init_sample_rate = sample_rate
|
||||
self._sample_rate = 0
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._tracing_enabled: bool = False
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
self._ttfs_p99_latency = ttfs_p99_latency
|
||||
@@ -110,12 +119,16 @@ class STTService(AIService):
|
||||
# STT TTFB tracking state
|
||||
self._stt_ttfb_timeout = stt_ttfb_timeout
|
||||
self._ttfb_timeout_task: Optional[asyncio.Task] = None
|
||||
self._speech_end_time: Optional[float] = None
|
||||
self._user_speaking: bool = False
|
||||
self._last_transcription_time: Optional[float] = None
|
||||
self._finalize_pending: bool = False
|
||||
self._finalize_requested: bool = False
|
||||
|
||||
# Keepalive state
|
||||
self._keepalive_timeout = keepalive_timeout
|
||||
self._keepalive_interval = keepalive_interval
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
self._last_audio_time: float = 0
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
@@ -202,12 +215,12 @@ class STTService(AIService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up STT service resources."""
|
||||
await super().cleanup()
|
||||
await self._cancel_ttfb_timeout()
|
||||
await self._cancel_keepalive_task()
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
logger.info(f"Updating STT settings: {self._settings}")
|
||||
@@ -239,6 +252,8 @@ class STTService(AIService):
|
||||
if self._muted:
|
||||
return
|
||||
|
||||
self._last_audio_time = time.monotonic()
|
||||
|
||||
# UserAudioRawFrame contains a user_id (e.g. Daily, Livekit)
|
||||
if hasattr(frame, "user_id"):
|
||||
self._user_id = frame.user_id
|
||||
@@ -308,23 +323,16 @@ class STTService(AIService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
# Store the transcription time for TTFB calculation
|
||||
self._last_transcription_time = time.time()
|
||||
|
||||
# Set finalized from pending state and auto-reset
|
||||
if self._finalize_pending:
|
||||
frame.finalized = True
|
||||
self._finalize_pending = False
|
||||
|
||||
# If this is a finalized transcription, report TTFB immediately
|
||||
if frame.finalized and self._speech_end_time is not None:
|
||||
ttfb = self._last_transcription_time - self._speech_end_time
|
||||
await self._emit_stt_ttfb_metric(ttfb)
|
||||
if frame.finalized:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Cancel the timeout since we've already reported
|
||||
await self._cancel_ttfb_timeout()
|
||||
# Clear state
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@@ -354,8 +362,6 @@ class STTService(AIService):
|
||||
while user is still speaking.
|
||||
"""
|
||||
await self._cancel_ttfb_timeout()
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
"""Handle VAD user started speaking frame to start tracking transcriptions.
|
||||
@@ -389,7 +395,8 @@ class STTService(AIService):
|
||||
# Calculate the actual speech end time (current time minus VAD stop delay).
|
||||
# This approximates when the last user audio was sent to the STT service,
|
||||
# which we use to measure against the eventual transcription response.
|
||||
self._speech_end_time = frame.timestamp - frame.stop_secs
|
||||
speech_end_time = frame.timestamp - frame.stop_secs
|
||||
await self.start_ttfb_metrics(start_time=speech_end_time)
|
||||
|
||||
# Start timeout task (any previous timeout was cancelled by VADUserStartedSpeakingFrame
|
||||
# or InterruptionFrame)
|
||||
@@ -398,43 +405,79 @@ class STTService(AIService):
|
||||
)
|
||||
|
||||
async def _ttfb_timeout_handler(self):
|
||||
"""Wait for timeout then report TTFB using the last transcription timestamp.
|
||||
"""Wait for timeout then report TTFB.
|
||||
|
||||
This timeout allows the final transcription to arrive before we calculate
|
||||
and report TTFB. If no transcription arrived, no TTFB is reported.
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(self._stt_ttfb_timeout)
|
||||
|
||||
# Report TTFB if we have both speech end time and transcription time
|
||||
if self._speech_end_time is not None and self._last_transcription_time is not None:
|
||||
ttfb = self._last_transcription_time - self._speech_end_time
|
||||
await self._emit_stt_ttfb_metric(ttfb)
|
||||
|
||||
# Clear state after reporting
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
await self.stop_ttfb_metrics()
|
||||
except asyncio.CancelledError:
|
||||
# Task was cancelled (new utterance or interruption), which is expected behavior
|
||||
pass
|
||||
finally:
|
||||
self._ttfb_timeout_task = None
|
||||
|
||||
async def _emit_stt_ttfb_metric(self, ttfb: float):
|
||||
"""Emit STT TTFB metric if value is non-negative.
|
||||
def _create_keepalive_task(self):
|
||||
"""Start the keepalive task if keepalive is enabled."""
|
||||
if self._keepalive_timeout is not None:
|
||||
self._last_audio_time = time.monotonic()
|
||||
self._keepalive_task = self.create_task(
|
||||
self._keepalive_task_handler(), name="keepalive"
|
||||
)
|
||||
|
||||
async def _cancel_keepalive_task(self):
|
||||
"""Stop the keepalive task if running."""
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic silent audio to prevent the server from closing the connection.
|
||||
|
||||
When keepalive is enabled, this task checks periodically if the connection
|
||||
has been idle (no audio sent) for longer than keepalive_timeout seconds.
|
||||
If so, it generates silent 16-bit mono PCM audio and passes it to
|
||||
_send_keepalive() for service-specific formatting and sending.
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(self._keepalive_interval)
|
||||
try:
|
||||
if not self._is_keepalive_ready():
|
||||
continue
|
||||
elapsed = time.monotonic() - self._last_audio_time
|
||||
if elapsed < self._keepalive_timeout:
|
||||
continue
|
||||
num_samples = int(self.sample_rate * _KEEPALIVE_SILENCE_DURATION)
|
||||
silence = b"\x00" * (num_samples * 2)
|
||||
await self._send_keepalive(silence)
|
||||
self._last_audio_time = time.monotonic()
|
||||
logger.trace(f"{self} sent keepalive silence")
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
|
||||
def _is_keepalive_ready(self) -> bool:
|
||||
"""Check if the service is ready to send keepalive.
|
||||
|
||||
Subclasses should override this to check their connection state.
|
||||
|
||||
Returns:
|
||||
True if keepalive can be sent.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _send_keepalive(self, silence: bytes):
|
||||
"""Send silent audio to keep the connection alive.
|
||||
|
||||
Subclasses that enable keepalive must override this to deliver silence
|
||||
in their service-specific protocol.
|
||||
|
||||
Args:
|
||||
ttfb: The TTFB value in seconds.
|
||||
silence: Silent 16-bit mono PCM audio bytes.
|
||||
"""
|
||||
if ttfb >= 0:
|
||||
logger.debug(f"{self} TTFB: {ttfb:.3f}s")
|
||||
if self.metrics_enabled:
|
||||
ttfb_data = TTFBMetricsData(
|
||||
processor=self.name,
|
||||
model=self.model_name,
|
||||
value=ttfb,
|
||||
)
|
||||
await super().push_frame(MetricsFrame(data=[ttfb_data]))
|
||||
raise NotImplementedError("Subclasses must override _send_keepalive")
|
||||
|
||||
|
||||
class SegmentedSTTService(STTService):
|
||||
@@ -549,46 +592,27 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
Combines STT functionality with websocket connectivity, providing automatic
|
||||
error handling, reconnection capabilities, and optional silence-based keepalive.
|
||||
|
||||
The keepalive feature sends silent audio when no real audio has been sent for
|
||||
a configurable timeout, preventing servers from closing idle connections (e.g.
|
||||
when behind a ServiceSwitcher). Subclasses can override ``_send_keepalive()``
|
||||
to wrap the silence in a service-specific protocol.
|
||||
The keepalive feature (inherited from STTService) sends silent audio when no
|
||||
real audio has been sent for a configurable timeout, preventing servers from
|
||||
closing idle connections (e.g. when behind a ServiceSwitcher). Subclasses can
|
||||
override ``_send_keepalive()`` to wrap the silence in a service-specific protocol.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
reconnect_on_error: bool = True,
|
||||
keepalive_timeout: Optional[float] = None,
|
||||
keepalive_interval: float = 5.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Websocket STT service.
|
||||
|
||||
Args:
|
||||
reconnect_on_error: Whether to automatically reconnect on websocket errors.
|
||||
keepalive_timeout: Seconds of no audio before sending silence to keep the
|
||||
connection alive. None disables keepalive. Useful for services that
|
||||
close idle connections (e.g. behind a ServiceSwitcher).
|
||||
keepalive_interval: Seconds between idle checks when keepalive is enabled.
|
||||
**kwargs: Additional arguments passed to parent classes.
|
||||
**kwargs: Additional arguments passed to parent classes (including
|
||||
keepalive_timeout and keepalive_interval for STTService).
|
||||
"""
|
||||
STTService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._keepalive_timeout = keepalive_timeout
|
||||
self._keepalive_interval = keepalive_interval
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
self._last_audio_time: float = 0
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
"""Process an audio frame, tracking the last audio time for keepalive.
|
||||
|
||||
Args:
|
||||
frame: The audio frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
self._last_audio_time = time.monotonic()
|
||||
await super().process_audio_frame(frame, direction)
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect and start keepalive task if enabled."""
|
||||
@@ -612,44 +636,9 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
self._create_keepalive_task()
|
||||
return result
|
||||
|
||||
def _create_keepalive_task(self):
|
||||
"""Start the keepalive task if keepalive is enabled."""
|
||||
if self._keepalive_timeout is not None:
|
||||
self._last_audio_time = time.monotonic()
|
||||
self._keepalive_task = self.create_task(
|
||||
self._keepalive_task_handler(), name="keepalive"
|
||||
)
|
||||
|
||||
async def _cancel_keepalive_task(self):
|
||||
"""Stop the keepalive task if running."""
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic silent audio to prevent the server from closing the connection.
|
||||
|
||||
When keepalive is enabled, this task checks periodically if the connection
|
||||
has been idle (no audio sent) for longer than keepalive_timeout seconds.
|
||||
If so, it generates silent 16-bit mono PCM audio and passes it to
|
||||
_send_keepalive() for service-specific formatting and sending.
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(self._keepalive_interval)
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is not State.OPEN:
|
||||
continue
|
||||
elapsed = time.monotonic() - self._last_audio_time
|
||||
if elapsed < self._keepalive_timeout:
|
||||
continue
|
||||
num_samples = int(self.sample_rate * _KEEPALIVE_SILENCE_DURATION)
|
||||
silence = b"\x00" * (num_samples * 2)
|
||||
await self._send_keepalive(silence)
|
||||
self._last_audio_time = time.monotonic()
|
||||
logger.trace(f"{self} sent keepalive silence")
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
def _is_keepalive_ready(self) -> bool:
|
||||
"""Check if the websocket is open and ready for keepalive."""
|
||||
return self._websocket is not None and self._websocket.state is State.OPEN
|
||||
|
||||
async def _send_keepalive(self, silence: bytes):
|
||||
"""Send silent audio over the websocket to keep the connection alive.
|
||||
|
||||
@@ -208,8 +208,6 @@ class TTSService(AIService):
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
|
||||
if text_filter:
|
||||
import warnings
|
||||
|
||||
@@ -349,7 +347,6 @@ class TTSService(AIService):
|
||||
self._sample_rate = self._init_sample_rate or frame.audio_out_sample_rate
|
||||
if self._push_stop_frames and not self._stop_frame_task:
|
||||
self._stop_frame_task = self.create_task(self._stop_frame_handler())
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the TTS service.
|
||||
@@ -1045,14 +1042,25 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
audio from context ID "A" will be played first.
|
||||
"""
|
||||
|
||||
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
|
||||
_CONTEXT_KEEPALIVE = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
reuse_context_id_within_turn: bool = True,
|
||||
reconnect_on_error: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Audio Context TTS service.
|
||||
|
||||
Args:
|
||||
reuse_context_id_within_turn: Whether the service should reuse context IDs within the same turn.
|
||||
reconnect_on_error: Whether to automatically reconnect on websocket errors.
|
||||
**kwargs: Additional arguments passed to the parent WebsocketTTSService.
|
||||
"""
|
||||
super().__init__(reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._reuse_context_id_within_turn = reuse_context_id_within_turn
|
||||
self._context_id = None
|
||||
self._contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = None
|
||||
|
||||
@@ -1062,6 +1070,10 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
Args:
|
||||
context_id: Unique identifier for the audio context.
|
||||
"""
|
||||
# Set the context ID if not already set
|
||||
if not self._context_id:
|
||||
self._context_id = context_id
|
||||
|
||||
await self._contexts_queue.put(context_id)
|
||||
self._contexts[context_id] = asyncio.Queue()
|
||||
logger.trace(f"{self} created audio context {context_id}")
|
||||
@@ -1094,6 +1106,32 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
else:
|
||||
logger.warning(f"{self} unable to remove context {context_id}")
|
||||
|
||||
def has_active_audio_context(self) -> bool:
|
||||
"""Check if there is an active audio context.
|
||||
|
||||
Returns:
|
||||
True if an active audio context exists, False otherwise.
|
||||
"""
|
||||
return self._context_id is not None and self.audio_context_available(self._context_id)
|
||||
|
||||
def get_active_audio_context_id(self) -> Optional[str]:
|
||||
"""Get the active audio context ID.
|
||||
|
||||
Returns:
|
||||
The active context ID, or None if no context is active.
|
||||
"""
|
||||
return self._context_id
|
||||
|
||||
async def remove_active_audio_context(self):
|
||||
"""Remove the active audio context."""
|
||||
if self._context_id:
|
||||
await self.remove_audio_context(self._context_id)
|
||||
self.reset_active_audio_context()
|
||||
|
||||
def reset_active_audio_context(self):
|
||||
"""Reset the active audio context."""
|
||||
self._context_id = None
|
||||
|
||||
def audio_context_available(self, context_id: str) -> bool:
|
||||
"""Check whether the given audio context is registered.
|
||||
|
||||
@@ -1105,6 +1143,26 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
"""
|
||||
return context_id in self._contexts
|
||||
|
||||
def create_context_id(self) -> str:
|
||||
"""Generate or reuse a context ID based on concurrent TTS support.
|
||||
|
||||
If _reuse_context_id_within_turn is False and a context already exists,
|
||||
the existing context ID is returned. Otherwise, a new unique context
|
||||
ID is generated.
|
||||
|
||||
Returns:
|
||||
A context ID string for the TTS request.
|
||||
"""
|
||||
if self._reuse_context_id_within_turn and self._context_id:
|
||||
self._refresh_active_audio_context()
|
||||
return self._context_id
|
||||
return super().create_context_id()
|
||||
|
||||
def _refresh_active_audio_context(self):
|
||||
"""Signal that the audio context is still in use, resetting the timeout."""
|
||||
if self.has_active_audio_context():
|
||||
self._contexts[self._context_id].put_nowait(AudioContextTTSService._CONTEXT_KEEPALIVE)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the audio context TTS service.
|
||||
|
||||
@@ -1140,6 +1198,7 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self._stop_audio_context_task()
|
||||
self.reset_active_audio_context()
|
||||
self._create_audio_context_task()
|
||||
|
||||
def _create_audio_context_task(self):
|
||||
@@ -1158,6 +1217,7 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
running = True
|
||||
while running:
|
||||
context_id = await self._contexts_queue.get()
|
||||
self._context_id = context_id
|
||||
|
||||
if context_id:
|
||||
# Process the audio context until the context doesn't have more
|
||||
@@ -1166,11 +1226,15 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
|
||||
# We just finished processing the context, so we can safely remove it.
|
||||
del self._contexts[context_id]
|
||||
self.reset_active_audio_context()
|
||||
|
||||
# Append some silence between sentences.
|
||||
silence = b"\x00" * self.sample_rate
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=silence, sample_rate=self.sample_rate, num_channels=1
|
||||
audio=silence,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
@@ -1186,6 +1250,10 @@ class AudioContextTTSService(WebsocketTTSService):
|
||||
while running:
|
||||
try:
|
||||
frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
|
||||
if frame is AudioContextTTSService._CONTEXT_KEEPALIVE:
|
||||
# Context is still in use, reset the timeout.
|
||||
continue
|
||||
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
running = frame is not None
|
||||
|
||||
@@ -237,6 +237,18 @@ class BaseOutputTransport(FrameProcessor):
|
||||
else:
|
||||
await self._write_dtmf_audio(frame)
|
||||
|
||||
async def write_transport_frame(self, frame: Frame):
|
||||
"""Handle a queued frame after preceding audio has been sent.
|
||||
|
||||
Override in transport subclasses to handle custom frame types that
|
||||
flow through the audio queue. Called by the media sender after the
|
||||
frame has waited for any preceding audio to finish.
|
||||
|
||||
Args:
|
||||
frame: The frame to handle.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _supports_native_dtmf(self) -> bool:
|
||||
"""Override in transport implementations that support native DTMF.
|
||||
|
||||
@@ -613,6 +625,11 @@ class BaseOutputTransport(FrameProcessor):
|
||||
downstream_frame.transport_destination = self._destination
|
||||
upstream_frame = BotStartedSpeakingFrame()
|
||||
upstream_frame.transport_destination = self._destination
|
||||
|
||||
# Setting the siblings id
|
||||
upstream_frame.broadcast_sibling_id = downstream_frame.id
|
||||
downstream_frame.broadcast_sibling_id = upstream_frame.id
|
||||
|
||||
await self._transport.push_frame(downstream_frame)
|
||||
await self._transport.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
@@ -635,6 +652,11 @@ class BaseOutputTransport(FrameProcessor):
|
||||
downstream_frame.transport_destination = self._destination
|
||||
upstream_frame = BotStoppedSpeakingFrame()
|
||||
upstream_frame.transport_destination = self._destination
|
||||
|
||||
# Setting the siblings id
|
||||
upstream_frame.broadcast_sibling_id = downstream_frame.id
|
||||
downstream_frame.broadcast_sibling_id = upstream_frame.id
|
||||
|
||||
await self._transport.push_frame(downstream_frame)
|
||||
await self._transport.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
@@ -681,6 +703,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._transport.send_message(frame)
|
||||
elif isinstance(frame, OutputDTMFFrame):
|
||||
await self._transport.write_dtmf(frame)
|
||||
else:
|
||||
await self._transport.write_transport_frame(frame)
|
||||
|
||||
def _next_frame(self) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate the next frame for audio processing.
|
||||
|
||||
@@ -15,7 +15,7 @@ import asyncio
|
||||
import time
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
@@ -25,7 +25,7 @@ from pydantic import BaseModel
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
ControlFrame,
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
@@ -183,34 +183,44 @@ class DailyInputTransportMessageUrgentFrame(DailyInputTransportMessageFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyUpdateRemoteParticipantsFrame(ControlFrame):
|
||||
"""Frame to update remote participants in Daily calls.
|
||||
class DailySIPTransferFrame(DataFrame):
|
||||
"""SIP call transfer frame for transport queuing.
|
||||
|
||||
.. deprecated:: 0.0.87
|
||||
`DailyUpdateRemoteParticipantsFrame` is deprecated and will be removed in a future version.
|
||||
Create your own custom frame and use a custom processor to handle it or use, for example,
|
||||
`on_after_push_frame` event instead in the output transport.
|
||||
A SIP call transfer that will be queued. The transfer will happen after any
|
||||
preceding audio finishes playing, allowing the bot to complete its current
|
||||
utterance before the transfer occurs.
|
||||
|
||||
Parameters:
|
||||
settings: SIP call transfer settings.
|
||||
"""
|
||||
|
||||
settings: Mapping[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailySIPReferFrame(DataFrame):
|
||||
"""SIP REFER frame for transport queuing.
|
||||
|
||||
A SIP REFER that will be queued. The REFER will happen after any preceding
|
||||
audio finishes playing, allowing the bot to complete its current utterance
|
||||
before the REFER occurs.
|
||||
|
||||
Parameters:
|
||||
settings: SIP REFER settings.
|
||||
"""
|
||||
|
||||
settings: Mapping[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyUpdateRemoteParticipantsFrame(DataFrame):
|
||||
"""Frame to update remote participants in Daily calls.
|
||||
|
||||
Parameters:
|
||||
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
|
||||
"""
|
||||
|
||||
remote_participants: Mapping[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"DailyUpdateRemoteParticipantsFrame is deprecated and will be removed in a future version."
|
||||
"Instead, create your own custom frame and handle it in the "
|
||||
'`@transport.output().event_handler("on_after_push_frame")` event handler or a '
|
||||
"custom processor.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
remote_participants: Mapping[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
@@ -501,6 +511,7 @@ class DailyTransportClient(EventHandler):
|
||||
self._event_task = None
|
||||
self._audio_task = None
|
||||
self._video_task = None
|
||||
self._join_message_queue: list = []
|
||||
|
||||
# Input and ouput sample rates. They will be initialize on setup().
|
||||
self._in_sample_rate = 0
|
||||
@@ -567,7 +578,8 @@ class DailyTransportClient(EventHandler):
|
||||
error: An error description or None.
|
||||
"""
|
||||
if not self._joined:
|
||||
return "Unable to send messages before joining."
|
||||
self._join_message_queue.append(frame)
|
||||
return None
|
||||
|
||||
participant_id = None
|
||||
if isinstance(
|
||||
@@ -768,6 +780,8 @@ class DailyTransportClient(EventHandler):
|
||||
await self._callbacks.on_joined(data)
|
||||
|
||||
self._joined_event.set()
|
||||
|
||||
await self._flush_join_messages()
|
||||
else:
|
||||
error_msg = f"Error joining {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
@@ -1541,6 +1555,12 @@ class DailyTransportClient(EventHandler):
|
||||
await callback(*args)
|
||||
queue.task_done()
|
||||
|
||||
async def _flush_join_messages(self):
|
||||
"""Send any messages that were queued before join completed."""
|
||||
for frame in self._join_message_queue:
|
||||
await self.send_message(frame)
|
||||
self._join_message_queue.clear()
|
||||
|
||||
def _get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
"""Get the event loop from the task manager."""
|
||||
if not self._task_manager:
|
||||
@@ -1946,18 +1966,6 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process outgoing frames, including transport messages.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def send_message(
|
||||
self, frame: OutputTransportMessageFrame | OutputTransportMessageUrgentFrame
|
||||
):
|
||||
@@ -1968,7 +1976,7 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
error = await self._client.send_message(frame)
|
||||
if error:
|
||||
logger.error(f"Unable to send message: {error}")
|
||||
await self.push_error(f"Unable to send message: {error}")
|
||||
|
||||
async def register_video_destination(self, destination: str):
|
||||
"""Register a video output destination.
|
||||
@@ -2011,6 +2019,25 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
return await self._client.write_video_frame(frame)
|
||||
|
||||
async def write_transport_frame(self, frame: Frame):
|
||||
"""Handle queued SIP frames after preceding audio has been sent.
|
||||
|
||||
Args:
|
||||
frame: The frame to handle.
|
||||
"""
|
||||
if isinstance(frame, DailySIPTransferFrame):
|
||||
error = await self._client.sip_call_transfer(frame.settings)
|
||||
if error:
|
||||
await self.push_error(f"Unable to transfer SIP call: {error}")
|
||||
elif isinstance(frame, DailySIPReferFrame):
|
||||
error = await self._client.sip_refer(frame.settings)
|
||||
if error:
|
||||
await self.push_error(f"Unable to perform SIP REFER: {error}")
|
||||
elif isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
error = await self._client.update_remote_participants(frame.remote_participants)
|
||||
if error:
|
||||
await self.push_error(f"Unable to update remote participants: {error}")
|
||||
|
||||
def _supports_native_dtmf(self) -> bool:
|
||||
"""Daily supports native DTMF via telephone events.
|
||||
|
||||
@@ -2039,6 +2066,61 @@ class DailyTransport(BaseTransport):
|
||||
Provides comprehensive Daily integration including audio/video streaming,
|
||||
transcription, recording, dial-in/out functionality, and real-time communication
|
||||
features for conversational AI applications.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_joined: Called when the bot joins the room. Args: (data: dict)
|
||||
- on_left: Called when the bot leaves the room.
|
||||
- on_before_leave: [sync] Called just before the bot leaves the room.
|
||||
- on_error: Called when a transport error occurs. Args: (error: str)
|
||||
- on_call_state_updated: Called when the call state changes. Args: (state: str)
|
||||
- on_first_participant_joined: Called when the first participant joins.
|
||||
Args: (participant: dict)
|
||||
- on_participant_joined: Called when any participant joins.
|
||||
Args: (participant: dict)
|
||||
- on_participant_left: Called when a participant leaves.
|
||||
Args: (participant: dict, reason: str)
|
||||
- on_participant_updated: Called when a participant's state changes.
|
||||
Args: (participant: dict)
|
||||
- on_client_connected: Called when a participant connects (alias for
|
||||
on_participant_joined). Args: (participant: dict)
|
||||
- on_client_disconnected: Called when a participant disconnects (alias for
|
||||
on_participant_left). Args: (participant: dict)
|
||||
- on_active_speaker_changed: Called when the active speaker changes.
|
||||
Args: (participant: dict)
|
||||
- on_app_message: Called when an app message is received.
|
||||
Args: (message: Any, sender: str)
|
||||
- on_transcription_message: Called when a transcription message is received.
|
||||
Args: (message: dict)
|
||||
- on_recording_started: Called when recording starts. Args: (status: str)
|
||||
- on_recording_stopped: Called when recording stops. Args: (stream_id: str)
|
||||
- on_recording_error: Called when a recording error occurs.
|
||||
Args: (stream_id: str, message: str)
|
||||
- on_dialin_connected: Called when a dial-in call connects. Args: (data: dict)
|
||||
- on_dialin_ready: Called when the SIP endpoint is ready.
|
||||
Args: (sip_endpoint: str)
|
||||
- on_dialin_stopped: Called when a dial-in call stops. Args: (data: dict)
|
||||
- on_dialin_error: Called when a dial-in error occurs. Args: (data: dict)
|
||||
- on_dialin_warning: Called when a dial-in warning occurs. Args: (data: dict)
|
||||
- on_dialout_answered: Called when a dial-out call is answered. Args: (data: dict)
|
||||
- on_dialout_connected: Called when a dial-out call connects. Args: (data: dict)
|
||||
- on_dialout_stopped: Called when a dial-out call stops. Args: (data: dict)
|
||||
- on_dialout_error: Called when a dial-out error occurs. Args: (data: dict)
|
||||
- on_dialout_warning: Called when a dial-out warning occurs. Args: (data: dict)
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await task.queue_frame(TTSSpeakFrame("Hello!"))
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_app_message")
|
||||
async def on_app_message(transport, message, sender):
|
||||
logger.info(f"Message from {sender}: {message}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -289,6 +289,17 @@ class HeyGenTransport(BaseTransport):
|
||||
When used, the Pipecat bot joins the same virtual room as the HeyGen Avatar and the user.
|
||||
This is achieved by using `HeyGenTransport`, which initiates the conversation via
|
||||
`HeyGenApi` and obtains a room URL that all participants connect to.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_client_connected(transport, participant): Participant connected to the session
|
||||
- on_client_disconnected(transport, participant): Participant disconnected from the session
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, participant):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -950,6 +950,41 @@ class LiveKitTransport(BaseTransport):
|
||||
Provides comprehensive LiveKit integration including audio streaming, data
|
||||
messaging, participant management, and room event handling for conversational
|
||||
AI applications.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_connected: Called when the bot connects to the room.
|
||||
- on_disconnected: Called when the bot disconnects from the room.
|
||||
- on_before_disconnect: [sync] Called just before the bot disconnects.
|
||||
- on_call_state_updated: Called when the call state changes. Args: (state: str)
|
||||
- on_first_participant_joined: Called when the first participant joins.
|
||||
Args: (participant_id: str)
|
||||
- on_participant_connected: Called when a participant connects.
|
||||
Args: (participant_id: str)
|
||||
- on_participant_disconnected: Called when a participant disconnects.
|
||||
Args: (participant_id: str)
|
||||
- on_participant_left: Called when a participant leaves.
|
||||
Args: (participant_id: str, reason: str)
|
||||
- on_audio_track_subscribed: Called when an audio track is subscribed.
|
||||
Args: (participant_id: str)
|
||||
- on_audio_track_unsubscribed: Called when an audio track is unsubscribed.
|
||||
Args: (participant_id: str)
|
||||
- on_video_track_subscribed: Called when a video track is subscribed.
|
||||
Args: (participant_id: str)
|
||||
- on_video_track_unsubscribed: Called when a video track is unsubscribed.
|
||||
Args: (participant_id: str)
|
||||
- on_data_received: Called when data is received from a participant.
|
||||
Args: (data: bytes, participant_id: str)
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant_id):
|
||||
await task.queue_frame(TTSSpeakFrame("Hello!"))
|
||||
|
||||
@transport.event_handler("on_participant_disconnected")
|
||||
async def on_participant_disconnected(transport, participant_id):
|
||||
await task.queue_frame(EndFrame())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -864,6 +864,18 @@ class SmallWebRTCTransport(BaseTransport):
|
||||
|
||||
Provides bidirectional audio and video streaming over WebRTC connections
|
||||
with support for application messaging and connection event handling.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_client_connected(transport, client): Client connected to WebRTC session
|
||||
- on_client_disconnected(transport, client): Client disconnected from WebRTC session
|
||||
- on_client_message(transport, message, client): Received a data channel message
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -519,7 +519,7 @@ class TavusInputTransport(BaseInputTransport):
|
||||
"""Handle received participant audio data."""
|
||||
frame = InputAudioRawFrame(
|
||||
audio=audio.audio_frames,
|
||||
sample_rate=audio.audio_frames,
|
||||
sample_rate=audio.sample_rate,
|
||||
num_channels=audio.num_channels,
|
||||
)
|
||||
frame.transport_source = audio_source
|
||||
@@ -661,6 +661,17 @@ class TavusTransport(BaseTransport):
|
||||
When used, the Pipecat bot joins the same virtual room as the Tavus Avatar and the user.
|
||||
This is achieved by using `TavusTransportClient`, which initiates the conversation via
|
||||
`TavusApi` and obtains a room URL that all participants connect to.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_client_connected(transport, participant): Participant connected to the session
|
||||
- on_client_disconnected(transport, participant): Participant disconnected from the session
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, participant):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -471,6 +471,17 @@ class WebsocketClientTransport(BaseTransport):
|
||||
|
||||
Provides a complete WebSocket client transport implementation with
|
||||
input and output capabilities, connection management, and event handling.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_connected(transport): Connected to WebSocket server
|
||||
- on_disconnected(transport): Disconnected from WebSocket server
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_connected")
|
||||
async def on_connected(transport):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -534,6 +534,18 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
|
||||
Provides bidirectional WebSocket communication with frame serialization,
|
||||
session management, and event handling for client connections and timeouts.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_client_connected(transport, websocket): Client WebSocket connected
|
||||
- on_client_disconnected(transport, websocket): Client WebSocket disconnected
|
||||
- on_session_timeout(transport, websocket): Session timed out
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, websocket):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -421,6 +421,19 @@ class WebsocketServerTransport(BaseTransport):
|
||||
Provides a complete WebSocket server implementation with separate input and
|
||||
output transports, client connection management, and event handling for
|
||||
real-time audio and data streaming applications.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_client_connected(transport, websocket): Client WebSocket connected
|
||||
- on_client_disconnected(transport, websocket): Client WebSocket disconnected
|
||||
- on_session_timeout(transport, websocket): Session timed out
|
||||
- on_websocket_ready(transport): WebSocket server is ready to accept connections
|
||||
|
||||
Example::
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, websocket):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -10,12 +10,15 @@ import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
UserSpeakingFrame,
|
||||
UserIdleTimeoutUpdateFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
@@ -25,14 +28,14 @@ class UserIdleController(BaseObject):
|
||||
"""Controller for managing user idle detection.
|
||||
|
||||
This class monitors user activity and triggers an event when the user has been
|
||||
idle (not speaking) for a configured timeout period. It only starts monitoring
|
||||
after the first conversation activity and does not trigger while the bot is
|
||||
speaking or function calls are in progress.
|
||||
idle (not speaking) for a configured timeout period after the bot finishes
|
||||
speaking. The timer starts when BotStoppedSpeakingFrame is received and is
|
||||
cancelled when someone starts speaking again (UserStartedSpeakingFrame or
|
||||
BotStartedSpeakingFrame).
|
||||
|
||||
The controller tracks activity using continuous frames (UserSpeakingFrame and
|
||||
BotSpeakingFrame) which are emitted repeatedly while speaking is happening, and
|
||||
state-based tracking for function calls (FunctionCallsStartedFrame and
|
||||
FunctionCallResultFrame) which are only sent at start and end.
|
||||
The timer is suppressed while a user turn is in progress to avoid false
|
||||
triggers during interruptions (where BotStoppedSpeakingFrame arrives while
|
||||
the user is still speaking).
|
||||
|
||||
Event handlers available:
|
||||
|
||||
@@ -49,12 +52,13 @@ class UserIdleController(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_idle_timeout: float,
|
||||
user_idle_timeout: float = 0,
|
||||
):
|
||||
"""Initialize the user idle controller.
|
||||
|
||||
Args:
|
||||
user_idle_timeout: Timeout in seconds before considering the user idle.
|
||||
0 disables idle detection.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -62,11 +66,9 @@ class UserIdleController(BaseObject):
|
||||
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
|
||||
self._conversation_started = False
|
||||
self._function_call_in_progress = False
|
||||
|
||||
self.user_idle_event = asyncio.Event()
|
||||
self.user_idle_task: Optional[asyncio.Task] = None
|
||||
self._user_turn_in_progress: bool = False
|
||||
self._function_calls_in_progress: int = 0
|
||||
self._idle_timer_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_user_turn_idle", sync=True)
|
||||
|
||||
@@ -85,19 +87,10 @@ class UserIdleController(BaseObject):
|
||||
"""
|
||||
self._task_manager = task_manager
|
||||
|
||||
if not self.user_idle_task:
|
||||
self.user_idle_task = self.task_manager.create_task(
|
||||
self.user_idle_task_handler(),
|
||||
f"{self}::user_idle_task_handler",
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the controller."""
|
||||
await super().cleanup()
|
||||
|
||||
if self.user_idle_task:
|
||||
await self.task_manager.cancel_task(self.user_idle_task)
|
||||
self.user_idle_task = None
|
||||
await self._cancel_idle_timer()
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to track user activity state.
|
||||
@@ -105,69 +98,60 @@ class UserIdleController(BaseObject):
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
"""
|
||||
# Start monitoring on first conversation activity
|
||||
if not self._conversation_started:
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, BotSpeakingFrame)):
|
||||
self._conversation_started = True
|
||||
self.user_idle_event.set()
|
||||
else:
|
||||
return
|
||||
if isinstance(frame, UserIdleTimeoutUpdateFrame):
|
||||
self._user_idle_timeout = frame.timeout
|
||||
if self._user_idle_timeout <= 0:
|
||||
await self._cancel_idle_timer()
|
||||
return
|
||||
|
||||
# Reset idle timer on continuous activity frames
|
||||
if isinstance(frame, (UserSpeakingFrame, BotSpeakingFrame)):
|
||||
await self._handle_activity(frame)
|
||||
# Track function call state (start/end frames, not continuous)
|
||||
if isinstance(frame, BotStoppedSpeakingFrame):
|
||||
# Only start the timer if the user isn't mid-turn and no function
|
||||
# calls are pending.
|
||||
#
|
||||
# Interruption case: the frame order is UserStartedSpeaking →
|
||||
# BotStoppedSpeaking → (user keeps talking) → UserStoppedSpeaking.
|
||||
# Without the user-turn guard the timer would start while the user
|
||||
# is still speaking.
|
||||
#
|
||||
# Function call case: normally FunctionCallsStarted arrives after
|
||||
# BotStoppedSpeaking and cancels the timer directly. But a race
|
||||
# condition can cause FunctionCallsStarted to arrive before
|
||||
# BotStoppedSpeaking when pushing a TTSSpeakFrame in the
|
||||
# on_function_calls_started event handler, so the counter guard
|
||||
# prevents the timer from starting while a function call is in progress.
|
||||
if not self._user_turn_in_progress and self._function_calls_in_progress == 0:
|
||||
await self._start_idle_timer()
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._cancel_idle_timer()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._user_turn_in_progress = True
|
||||
await self._cancel_idle_timer()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._user_turn_in_progress = False
|
||||
elif isinstance(frame, FunctionCallsStartedFrame):
|
||||
await self._handle_function_calls_started(frame)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
self._function_calls_in_progress += len(frame.function_calls)
|
||||
await self._cancel_idle_timer()
|
||||
elif isinstance(frame, (FunctionCallResultFrame, FunctionCallCancelFrame)):
|
||||
self._function_calls_in_progress = max(0, self._function_calls_in_progress - 1)
|
||||
|
||||
async def _handle_activity(self, _: UserSpeakingFrame | BotSpeakingFrame):
|
||||
"""Handle continuous activity frames that should reset the idle timer.
|
||||
async def _start_idle_timer(self):
|
||||
"""Start (or restart) the idle timer."""
|
||||
if self._user_idle_timeout <= 0:
|
||||
return
|
||||
await self._cancel_idle_timer()
|
||||
self._idle_timer_task = self.task_manager.create_task(
|
||||
self._idle_timer_expired(),
|
||||
f"{self}::idle_timer",
|
||||
)
|
||||
|
||||
These frames are emitted continuously while the user or bot is speaking,
|
||||
so we simply reset the timer whenever we receive them.
|
||||
async def _cancel_idle_timer(self):
|
||||
"""Cancel the idle timer if running."""
|
||||
if self._idle_timer_task:
|
||||
await self.task_manager.cancel_task(self._idle_timer_task)
|
||||
self._idle_timer_task = None
|
||||
|
||||
Args:
|
||||
frame: The activity frame to process.
|
||||
"""
|
||||
self.user_idle_event.set()
|
||||
|
||||
async def _handle_function_calls_started(self, _: FunctionCallsStartedFrame):
|
||||
"""Handle function calls started event.
|
||||
|
||||
Function calls can take longer than the timeout, so we track their state
|
||||
to prevent idle callbacks while they're in progress.
|
||||
|
||||
Args:
|
||||
frame: The FunctionCallsStartedFrame to process.
|
||||
"""
|
||||
self._function_call_in_progress = True
|
||||
self.user_idle_event.set()
|
||||
|
||||
async def _handle_function_call_result(self, _: FunctionCallResultFrame):
|
||||
"""Handle function call result event.
|
||||
|
||||
Args:
|
||||
frame: The FunctionCallResultFrame to process.
|
||||
"""
|
||||
self._function_call_in_progress = False
|
||||
self.user_idle_event.set()
|
||||
|
||||
async def user_idle_task_handler(self):
|
||||
"""Monitors for idle timeout and triggers events.
|
||||
|
||||
Runs in a loop until cancelled. The idle timer is reset whenever activity
|
||||
frames are received (UserSpeakingFrame or BotSpeakingFrame). Function calls
|
||||
are tracked via state since they only send start/end frames. If no activity
|
||||
is detected for the configured timeout period and no function call is in
|
||||
progress, the on_user_turn_idle event is triggered.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self.user_idle_event.wait(), timeout=self._user_idle_timeout)
|
||||
self.user_idle_event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
# Only trigger if conversation has started and no function call is in progress
|
||||
if self._conversation_started and not self._function_call_in_progress:
|
||||
await self._call_event_handler("on_user_turn_idle")
|
||||
async def _idle_timer_expired(self):
|
||||
"""Sleep for the timeout duration then fire the idle event."""
|
||||
await asyncio.sleep(self._user_idle_timeout)
|
||||
self._idle_timer_task = None
|
||||
await self._call_event_handler("on_user_turn_idle")
|
||||
|
||||
@@ -34,12 +34,8 @@ class SpeechTimeoutUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
after the user stops speaking, adjusted by the VAD stop_secs.
|
||||
|
||||
For services that support finalization (TranscriptionFrame.finalized=True),
|
||||
receiving the finalized transcript allows the strategy to shorten the
|
||||
timeout by removing the STT wait component, since only the
|
||||
`user_speech_timeout` portion is still needed. If `user_speech_timeout`
|
||||
has already elapsed when the transcript arrives, the original timeout
|
||||
continues running to provide a buffer for VAD to detect any resumed
|
||||
speech before triggering.
|
||||
the turn can be triggered immediately once the finalized transcript is
|
||||
received and the user resume speaking timeout has elapsed.
|
||||
"""
|
||||
|
||||
def __init__(self, *, user_speech_timeout: float = 0.6, **kwargs):
|
||||
@@ -130,26 +126,8 @@ class SpeechTimeoutUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
self._text += frame.text
|
||||
if frame.finalized:
|
||||
self._transcript_finalized = True
|
||||
# With the transcript finalized, we no longer need to wait for
|
||||
# STT latency. If a timeout is running (from VAD stop), recalculate
|
||||
# to use only user_speech_timeout, potentially shortening the wait.
|
||||
if self._timeout_task and self._vad_stopped_time is not None:
|
||||
elapsed = time.time() - self._vad_stopped_time
|
||||
remaining = self._user_speech_timeout - elapsed
|
||||
if remaining > 0:
|
||||
# Shorten timeout: replace STT+speech timeout with just
|
||||
# remaining speech timeout since STT is done.
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = self.task_manager.create_task(
|
||||
self._timeout_handler(remaining), f"{self}::_timeout_handler"
|
||||
)
|
||||
# If remaining <= 0: user_speech_timeout has elapsed, but the
|
||||
# original timeout (which may include extra STT wait time) is
|
||||
# still running. Let it complete naturally — this provides a
|
||||
# buffer for VAD to detect any resumed speech before triggering.
|
||||
elif self._timeout_task is None:
|
||||
# Timeout already completed, check if we should trigger now
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
# For finalized transcripts, check if we can trigger early
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
# Fallback: handle transcripts when no VAD stop was received.
|
||||
# This handles edge cases where transcripts arrive without VAD firing.
|
||||
@@ -200,10 +178,25 @@ class SpeechTimeoutUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
Conditions:
|
||||
- User is not currently speaking
|
||||
- We have transcription text
|
||||
- The timeout has fully elapsed (timeout task completed)
|
||||
- Either the timeout has elapsed OR we have a finalized transcript
|
||||
and user_speech_timeout has elapsed
|
||||
"""
|
||||
if self._vad_user_speaking or not self._text:
|
||||
return
|
||||
|
||||
# For finalized transcripts, check if user_speech_timeout has elapsed.
|
||||
# If elapsed, trigger user turn stopped immediately. Else, wait for user resume
|
||||
# speaking timeout.
|
||||
if self._transcript_finalized and self._vad_stopped_time is not None:
|
||||
elapsed = time.time() - self._vad_stopped_time
|
||||
if elapsed >= self._user_speech_timeout:
|
||||
# Cancel any remaining timeout since we're triggering now
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
await self.trigger_user_turn_stopped()
|
||||
return
|
||||
|
||||
# For non-finalized, only trigger if timeout task has completed
|
||||
if self._timeout_task is None:
|
||||
await self.trigger_user_turn_stopped()
|
||||
|
||||
@@ -66,7 +66,7 @@ class UserTurnProcessor(FrameProcessor):
|
||||
*,
|
||||
user_turn_strategies: Optional[UserTurnStrategies] = None,
|
||||
user_turn_stop_timeout: float = 5.0,
|
||||
user_idle_timeout: Optional[float] = None,
|
||||
user_idle_timeout: float = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the user turn processor.
|
||||
@@ -75,9 +75,9 @@ class UserTurnProcessor(FrameProcessor):
|
||||
user_turn_strategies: Configured strategies for starting and stopping user turns.
|
||||
user_turn_stop_timeout: Timeout in seconds to automatically stop a user turn
|
||||
if no activity is detected.
|
||||
user_idle_timeout: Optional timeout in seconds for detecting user idle state.
|
||||
If set, the processor will emit an `on_user_turn_idle` event when the user
|
||||
has been idle (not speaking) for this duration. Set to None to disable
|
||||
user_idle_timeout: Timeout in seconds for detecting user idle state.
|
||||
The processor will emit an `on_user_turn_idle` event when the user
|
||||
has been idle (not speaking) for this duration. Set to 0 to disable
|
||||
idle detection.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
@@ -104,13 +104,8 @@ class UserTurnProcessor(FrameProcessor):
|
||||
"on_user_turn_stop_timeout", self._on_user_turn_stop_timeout
|
||||
)
|
||||
|
||||
# Optional user idle controller
|
||||
self._user_idle_controller: Optional[UserIdleController] = None
|
||||
if user_idle_timeout:
|
||||
self._user_idle_controller = UserIdleController(user_idle_timeout=user_idle_timeout)
|
||||
self._user_idle_controller.add_event_handler(
|
||||
"on_user_turn_idle", self._on_user_turn_idle
|
||||
)
|
||||
self._user_idle_controller = UserIdleController(user_idle_timeout=user_idle_timeout)
|
||||
self._user_idle_controller.add_event_handler("on_user_turn_idle", self._on_user_turn_idle)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up processor resources."""
|
||||
@@ -149,14 +144,11 @@ class UserTurnProcessor(FrameProcessor):
|
||||
|
||||
await self._user_turn_controller.process_frame(frame)
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.process_frame(frame)
|
||||
await self._user_idle_controller.process_frame(frame)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
await self._user_turn_controller.setup(self.task_manager)
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.setup(self.task_manager)
|
||||
await self._user_idle_controller.setup(self.task_manager)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._cleanup()
|
||||
@@ -166,9 +158,7 @@ class UserTurnProcessor(FrameProcessor):
|
||||
|
||||
async def _cleanup(self):
|
||||
await self._user_turn_controller.cleanup()
|
||||
|
||||
if self._user_idle_controller:
|
||||
await self._user_idle_controller.cleanup()
|
||||
await self._user_idle_controller.cleanup()
|
||||
|
||||
async def _on_push_frame(
|
||||
self, controller, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
@@ -189,6 +179,8 @@ class UserTurnProcessor(FrameProcessor):
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
@@ -205,6 +197,8 @@ class UserTurnProcessor(FrameProcessor):
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy)
|
||||
|
||||
async def _on_user_turn_stop_timeout(self, controller):
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
|
||||
"""Base OpenTelemetry tracing decorators and utilities for Pipecat.
|
||||
|
||||
.. deprecated:: 0.0.103
|
||||
This module is unused and will be removed in a future release.
|
||||
Service tracing is handled by the decorators in
|
||||
:mod:`pipecat.utils.tracing.service_decorators`.
|
||||
|
||||
This module provides class and method level tracing capabilities
|
||||
similar to the original NVIDIA implementation.
|
||||
"""
|
||||
@@ -16,8 +21,16 @@ import contextlib
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
warnings.warn(
|
||||
"pipecat.utils.tracing.class_decorators is deprecated and will be removed in a future "
|
||||
"release. Use pipecat.utils.tracing.service_decorators instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
|
||||
# Import OpenTelemetry if available
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Conversation context provider for OpenTelemetry tracing in Pipecat.
|
||||
|
||||
This module provides a singleton context provider that manages the current
|
||||
conversation's tracing context, allowing services to create child spans
|
||||
that are properly associated with the conversation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
# Import types for type checking only
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import SpanContext
|
||||
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
|
||||
if is_tracing_available():
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import NonRecordingSpan, SpanContext, set_span_in_context
|
||||
|
||||
|
||||
class ConversationContextProvider:
|
||||
"""Provides access to the current conversation's tracing context.
|
||||
|
||||
This is a singleton that can be used to get the current conversation's
|
||||
span context to create child spans (like turns).
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_current_conversation_context: Optional["Context"] = None
|
||||
_conversation_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get the singleton instance.
|
||||
|
||||
Returns:
|
||||
The singleton ConversationContextProvider instance.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = ConversationContextProvider()
|
||||
return cls._instance
|
||||
|
||||
def set_current_conversation_context(
|
||||
self, span_context: Optional["SpanContext"], conversation_id: Optional[str] = None
|
||||
):
|
||||
"""Set the current conversation context.
|
||||
|
||||
Args:
|
||||
span_context: The span context for the current conversation or None to clear it.
|
||||
conversation_id: Optional ID for the conversation.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
return
|
||||
|
||||
self._conversation_id = conversation_id
|
||||
|
||||
if span_context:
|
||||
# Create a non-recording span from the span context
|
||||
non_recording_span = NonRecordingSpan(span_context)
|
||||
self._current_conversation_context = set_span_in_context(non_recording_span)
|
||||
else:
|
||||
self._current_conversation_context = None
|
||||
|
||||
def get_current_conversation_context(self) -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation context or None if not available.
|
||||
"""
|
||||
return self._current_conversation_context
|
||||
|
||||
def get_conversation_id(self) -> Optional[str]:
|
||||
"""Get the ID for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation ID or None if not available.
|
||||
"""
|
||||
return self._conversation_id
|
||||
|
||||
def generate_conversation_id(self) -> str:
|
||||
"""Generate a new conversation ID.
|
||||
|
||||
Returns:
|
||||
A new randomly generated UUID string.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_current_conversation_context() -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation context or None if not available.
|
||||
"""
|
||||
provider = ConversationContextProvider.get_instance()
|
||||
return provider.get_current_conversation_context()
|
||||
|
||||
|
||||
def get_conversation_id() -> Optional[str]:
|
||||
"""Get the ID for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation ID or None if not available.
|
||||
"""
|
||||
provider = ConversationContextProvider.get_instance()
|
||||
return provider.get_conversation_id()
|
||||
@@ -25,7 +25,6 @@ if TYPE_CHECKING:
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.utils.tracing.conversation_context_provider import get_current_conversation_context
|
||||
from pipecat.utils.tracing.service_attributes import (
|
||||
add_gemini_live_span_attributes,
|
||||
add_llm_span_attributes,
|
||||
@@ -34,7 +33,6 @@ from pipecat.utils.tracing.service_attributes import (
|
||||
add_tts_span_attributes,
|
||||
)
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
from pipecat.utils.tracing.turn_context_provider import get_current_turn_context
|
||||
|
||||
if is_tracing_available():
|
||||
from opentelemetry import context as context_api
|
||||
@@ -56,6 +54,19 @@ def _noop_decorator(func):
|
||||
return func
|
||||
|
||||
|
||||
def _get_turn_context(self):
|
||||
"""Get the current turn's tracing context if available.
|
||||
|
||||
Args:
|
||||
self: The service instance.
|
||||
|
||||
Returns:
|
||||
The turn context, or None if unavailable.
|
||||
"""
|
||||
tracing_ctx = getattr(self, "_tracing_context", None)
|
||||
return tracing_ctx.get_turn_context() if tracing_ctx else None
|
||||
|
||||
|
||||
def _get_parent_service_context(self):
|
||||
"""Get the parent service span context (internal use only).
|
||||
|
||||
@@ -71,12 +82,14 @@ def _get_parent_service_context(self):
|
||||
if not is_tracing_available():
|
||||
return None
|
||||
|
||||
# The parent span was created when Traceable was initialized and stored as self._span
|
||||
# TODO: Remove this block and delete class_decorators.py once Traceable is removed.
|
||||
# Legacy: support for classes inheriting from Traceable (currently unused, deprecated).
|
||||
if hasattr(self, "_span") and self._span:
|
||||
return trace.set_span_in_context(self._span)
|
||||
|
||||
# Fall back to conversation context if available
|
||||
conversation_context = get_current_conversation_context()
|
||||
# Use the conversation context set by TurnTraceObserver via TracingContext.
|
||||
tracing_ctx = getattr(self, "_tracing_context", None)
|
||||
conversation_context = tracing_ctx.get_conversation_context() if tracing_ctx else None
|
||||
if conversation_context:
|
||||
return conversation_context
|
||||
|
||||
@@ -183,8 +196,7 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
span_name = "tts"
|
||||
|
||||
# Get parent context
|
||||
turn_context = get_current_turn_context()
|
||||
parent_context = turn_context or _get_parent_service_context(self)
|
||||
parent_context = _get_turn_context(self) or _get_parent_service_context(self)
|
||||
|
||||
# Create span
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
@@ -218,19 +230,21 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
|
||||
@functools.wraps(f)
|
||||
async def gen_wrapper(self, text, *args, **kwargs):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
async for item in f(self, text, *args, **kwargs):
|
||||
yield item
|
||||
return
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
async for item in f(self, text, *args, **kwargs):
|
||||
yield item
|
||||
return
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
async with tracing_context(self, text):
|
||||
fn_called = True
|
||||
async for item in f(self, text, *args, **kwargs):
|
||||
yield item
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in TTS tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
async for item in f(self, text, *args, **kwargs):
|
||||
yield item
|
||||
|
||||
@@ -239,16 +253,18 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
|
||||
@functools.wraps(f)
|
||||
async def wrapper(self, text, *args, **kwargs):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, text, *args, **kwargs)
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, text, *args, **kwargs)
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
async with tracing_context(self, text):
|
||||
fn_called = True
|
||||
return await f(self, text, *args, **kwargs)
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in TTS tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
return await f(self, text, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -281,17 +297,16 @@ def traced_stt(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
async def wrapper(self, transcript, is_final, language=None):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, transcript, is_final, language)
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, transcript, is_final, language)
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
service_class_name = self.__class__.__name__
|
||||
span_name = "stt"
|
||||
|
||||
# Get the turn context first, then fall back to service context
|
||||
turn_context = get_current_turn_context()
|
||||
parent_context = turn_context or _get_parent_service_context(self)
|
||||
parent_context = _get_turn_context(self) or _get_parent_service_context(self)
|
||||
|
||||
# Create a new span as child of the turn span or service span
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
@@ -321,14 +336,16 @@ def traced_stt(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
)
|
||||
|
||||
# Call the original function
|
||||
fn_called = True
|
||||
return await f(self, transcript, is_final, language)
|
||||
except Exception as e:
|
||||
# Log any exception but don't disrupt the main flow
|
||||
logging.warning(f"Error in STT transcription tracing: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in STT tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
return await f(self, transcript, is_final, language)
|
||||
|
||||
return wrapper
|
||||
@@ -363,17 +380,16 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
async def wrapper(self, context, *args, **kwargs):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, context, *args, **kwargs)
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await f(self, context, *args, **kwargs)
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
service_class_name = self.__class__.__name__
|
||||
span_name = "llm"
|
||||
|
||||
# Get the parent context - turn context if available, otherwise service context
|
||||
turn_context = get_current_turn_context()
|
||||
parent_context = turn_context or _get_parent_service_context(self)
|
||||
parent_context = _get_turn_context(self) or _get_parent_service_context(self)
|
||||
|
||||
# Create a new span as child of the turn span or service span
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
@@ -515,6 +531,7 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
# Don't raise - let the function execute anyway
|
||||
|
||||
# Run function with modified push_frame to capture the output
|
||||
fn_called = True
|
||||
result = await f(self, context, *args, **kwargs)
|
||||
|
||||
# Add aggregated output after function completes, if available
|
||||
@@ -540,8 +557,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
if ttfb is not None:
|
||||
current_span.set_attribute("metrics.ttfb", ttfb)
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in LLM tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
return await f(self, context, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -573,17 +591,16 @@ def traced_gemini_live(operation: str) -> Callable:
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await func(self, *args, **kwargs)
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
service_class_name = self.__class__.__name__
|
||||
span_name = f"{operation}"
|
||||
|
||||
# Get the parent context - turn context if available, otherwise service context
|
||||
turn_context = get_current_turn_context()
|
||||
parent_context = turn_context or _get_parent_service_context(self)
|
||||
parent_context = _get_turn_context(self) or _get_parent_service_context(self)
|
||||
|
||||
# Create a new span as child of the turn span or service span
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
@@ -840,6 +857,7 @@ def traced_gemini_live(operation: str) -> Callable:
|
||||
current_span.set_attribute("metrics.ttfb", ttfb)
|
||||
|
||||
# Run the original function
|
||||
fn_called = True
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
return result
|
||||
@@ -850,8 +868,9 @@ def traced_gemini_live(operation: str) -> Callable:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in Gemini Live tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -880,17 +899,16 @@ def traced_openai_realtime(operation: str) -> Callable:
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
# Check if tracing is enabled for this service instance
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await func(self, *args, **kwargs)
|
||||
if not getattr(self, "_tracing_enabled", False):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
fn_called = False
|
||||
try:
|
||||
service_class_name = self.__class__.__name__
|
||||
span_name = f"{operation}"
|
||||
|
||||
# Get the parent context - turn context if available, otherwise service context
|
||||
turn_context = get_current_turn_context()
|
||||
parent_context = turn_context or _get_parent_service_context(self)
|
||||
parent_context = _get_turn_context(self) or _get_parent_service_context(self)
|
||||
|
||||
# Create a new span as child of the turn span or service span
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
@@ -1064,6 +1082,7 @@ def traced_openai_realtime(operation: str) -> Callable:
|
||||
current_span.set_attribute("metrics.ttfb", ttfb)
|
||||
|
||||
# Run the original function
|
||||
fn_called = True
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
return result
|
||||
@@ -1074,8 +1093,9 @@ def traced_openai_realtime(operation: str) -> Callable:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
if fn_called:
|
||||
raise
|
||||
logging.error(f"Error in OpenAI Realtime tracing (continuing without tracing): {e}")
|
||||
# If tracing fails, fall back to the original function
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
109
src/pipecat/utils/tracing/tracing_context.py
Normal file
109
src/pipecat/utils/tracing/tracing_context.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Pipeline-scoped tracing context for OpenTelemetry tracing in Pipecat.
|
||||
|
||||
This module provides a per-pipeline tracing context that holds the current
|
||||
conversation and turn span contexts. Each PipelineTask creates its own
|
||||
TracingContext, ensuring concurrent pipelines do not interfere with each other.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import SpanContext
|
||||
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
|
||||
if is_tracing_available():
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import NonRecordingSpan, SpanContext, set_span_in_context
|
||||
|
||||
|
||||
class TracingContext:
|
||||
"""Pipeline-scoped tracing context.
|
||||
|
||||
Holds the current conversation and turn span contexts for a single pipeline.
|
||||
Created by PipelineTask, passed to TurnTraceObserver (writer) and services
|
||||
(readers) via StartFrame.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the tracing context with empty state."""
|
||||
self._conversation_context: Optional["Context"] = None
|
||||
self._turn_context: Optional["Context"] = None
|
||||
self._conversation_id: Optional[str] = None
|
||||
|
||||
def set_conversation_context(
|
||||
self, span_context: Optional["SpanContext"], conversation_id: Optional[str] = None
|
||||
):
|
||||
"""Set the current conversation context.
|
||||
|
||||
Args:
|
||||
span_context: The span context for the current conversation or None to clear it.
|
||||
conversation_id: Optional ID for the conversation.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
return
|
||||
|
||||
self._conversation_id = conversation_id
|
||||
|
||||
if span_context:
|
||||
non_recording_span = NonRecordingSpan(span_context)
|
||||
self._conversation_context = set_span_in_context(non_recording_span)
|
||||
else:
|
||||
self._conversation_context = None
|
||||
|
||||
def get_conversation_context(self) -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation context or None if not available.
|
||||
"""
|
||||
return self._conversation_context
|
||||
|
||||
def set_turn_context(self, span_context: Optional["SpanContext"]):
|
||||
"""Set the current turn context.
|
||||
|
||||
Args:
|
||||
span_context: The span context for the current turn or None to clear it.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
return
|
||||
|
||||
if span_context:
|
||||
non_recording_span = NonRecordingSpan(span_context)
|
||||
self._turn_context = set_span_in_context(non_recording_span)
|
||||
else:
|
||||
self._turn_context = None
|
||||
|
||||
def get_turn_context(self) -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current turn.
|
||||
|
||||
Returns:
|
||||
The current turn context or None if not available.
|
||||
"""
|
||||
return self._turn_context
|
||||
|
||||
@property
|
||||
def conversation_id(self) -> Optional[str]:
|
||||
"""Get the ID for the current conversation.
|
||||
|
||||
Returns:
|
||||
The current conversation ID or None if not available.
|
||||
"""
|
||||
return self._conversation_id
|
||||
|
||||
@staticmethod
|
||||
def generate_conversation_id() -> str:
|
||||
"""Generate a new conversation ID.
|
||||
|
||||
Returns:
|
||||
A new randomly generated UUID string.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
@@ -1,81 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Turn context provider for OpenTelemetry tracing in Pipecat.
|
||||
|
||||
This module provides a singleton context provider that manages the current
|
||||
turn's tracing context, allowing services to create child spans that are
|
||||
properly associated with the conversation turn.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
# Import types for type checking only
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import SpanContext
|
||||
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
|
||||
if is_tracing_available():
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import NonRecordingSpan, SpanContext, set_span_in_context
|
||||
|
||||
|
||||
class TurnContextProvider:
|
||||
"""Provides access to the current turn's tracing context.
|
||||
|
||||
This is a singleton that services can use to get the current turn's
|
||||
span context to create child spans.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_current_turn_context: Optional["Context"] = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get the singleton instance.
|
||||
|
||||
Returns:
|
||||
The singleton TurnContextProvider instance.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = TurnContextProvider()
|
||||
return cls._instance
|
||||
|
||||
def set_current_turn_context(self, span_context: Optional["SpanContext"]):
|
||||
"""Set the current turn context.
|
||||
|
||||
Args:
|
||||
span_context: The span context for the current turn or None to clear it.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
return
|
||||
|
||||
if span_context:
|
||||
# Create a non-recording span from the span context
|
||||
non_recording_span = NonRecordingSpan(span_context)
|
||||
self._current_turn_context = set_span_in_context(non_recording_span)
|
||||
else:
|
||||
self._current_turn_context = None
|
||||
|
||||
def get_current_turn_context(self) -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current turn.
|
||||
|
||||
Returns:
|
||||
The current turn context or None if not available.
|
||||
"""
|
||||
return self._current_turn_context
|
||||
|
||||
|
||||
def get_current_turn_context() -> Optional["Context"]:
|
||||
"""Get the OpenTelemetry context for the current turn.
|
||||
|
||||
Returns:
|
||||
The current turn context or None if not available.
|
||||
"""
|
||||
provider = TurnContextProvider.get_instance()
|
||||
return provider.get_current_turn_context()
|
||||
@@ -19,9 +19,8 @@ from pipecat.frames.frames import StartFrame
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
|
||||
from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver
|
||||
from pipecat.utils.tracing.conversation_context_provider import ConversationContextProvider
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
from pipecat.utils.tracing.turn_context_provider import TurnContextProvider
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
|
||||
# Import types for type checking only
|
||||
if TYPE_CHECKING:
|
||||
@@ -49,6 +48,7 @@ class TurnTraceObserver(BaseObserver):
|
||||
latency_tracker: UserBotLatencyObserver,
|
||||
conversation_id: Optional[str] = None,
|
||||
additional_span_attributes: Optional[dict] = None,
|
||||
tracing_context: Optional[TracingContext] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the turn trace observer.
|
||||
@@ -58,11 +58,13 @@ class TurnTraceObserver(BaseObserver):
|
||||
latency_tracker: The latency tracking observer for user-bot latency.
|
||||
conversation_id: Optional conversation ID for grouping turns.
|
||||
additional_span_attributes: Additional attributes to add to spans.
|
||||
tracing_context: Pipeline-scoped tracing context for span hierarchy.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._turn_tracker = turn_tracker
|
||||
self._latency_tracker = latency_tracker
|
||||
self._tracing_context = tracing_context or TracingContext()
|
||||
self._current_span: Optional["Span"] = None
|
||||
self._current_turn_number: int = 0
|
||||
self._trace_context_map: Dict[int, "SpanContext"] = {}
|
||||
@@ -123,9 +125,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
return
|
||||
|
||||
# Generate a conversation ID if not provided
|
||||
context_provider = ConversationContextProvider.get_instance()
|
||||
if conversation_id is None:
|
||||
conversation_id = context_provider.generate_conversation_id()
|
||||
conversation_id = TracingContext.generate_conversation_id()
|
||||
logger.debug(f"Generated new conversation ID: {conversation_id}")
|
||||
|
||||
self._conversation_id = conversation_id
|
||||
@@ -140,8 +141,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
for k, v in (self._additional_span_attributes or {}).items():
|
||||
self._conversation_span.set_attribute(k, v)
|
||||
|
||||
# Update the conversation context provider
|
||||
context_provider.set_current_conversation_context(
|
||||
# Update the tracing context
|
||||
self._tracing_context.set_conversation_context(
|
||||
self._conversation_span.get_span_context(), conversation_id
|
||||
)
|
||||
|
||||
@@ -161,9 +162,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
self._current_span.end()
|
||||
self._current_span = None
|
||||
|
||||
# Clear the turn context provider
|
||||
context_provider = TurnContextProvider.get_instance()
|
||||
context_provider.set_current_turn_context(None)
|
||||
# Clear the turn context
|
||||
self._tracing_context.set_turn_context(None)
|
||||
|
||||
# Now end the conversation span if it exists
|
||||
if self._conversation_span:
|
||||
@@ -171,9 +171,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
self._conversation_span.end()
|
||||
self._conversation_span = None
|
||||
|
||||
# Clear the context provider
|
||||
context_provider = ConversationContextProvider.get_instance()
|
||||
context_provider.set_current_conversation_context(None)
|
||||
# Clear the conversation context
|
||||
self._tracing_context.set_conversation_context(None)
|
||||
|
||||
logger.debug(f"Ended tracing for Conversation {self._conversation_id}")
|
||||
self._conversation_id = None
|
||||
@@ -189,8 +188,7 @@ class TurnTraceObserver(BaseObserver):
|
||||
# Get the parent context - conversation if available, otherwise use root context
|
||||
parent_context = None
|
||||
if self._conversation_span:
|
||||
context_provider = ConversationContextProvider.get_instance()
|
||||
parent_context = context_provider.get_current_conversation_context()
|
||||
parent_context = self._tracing_context.get_conversation_context()
|
||||
|
||||
# Create a new span for this turn
|
||||
self._current_span = self._tracer.start_span("turn", context=parent_context)
|
||||
@@ -207,9 +205,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
# Store the span context so services can become children of this span
|
||||
self._trace_context_map[turn_number] = self._current_span.get_span_context()
|
||||
|
||||
# Update the context provider so services can access this span
|
||||
context_provider = TurnContextProvider.get_instance()
|
||||
context_provider.set_current_turn_context(self._current_span.get_span_context())
|
||||
# Update the tracing context so services can access this span
|
||||
self._tracing_context.set_turn_context(self._current_span.get_span_context())
|
||||
|
||||
logger.debug(f"Started tracing for Turn {turn_number}")
|
||||
|
||||
@@ -228,9 +225,8 @@ class TurnTraceObserver(BaseObserver):
|
||||
self._current_span.end()
|
||||
self._current_span = None
|
||||
|
||||
# Clear the context provider
|
||||
context_provider = TurnContextProvider.get_instance()
|
||||
context_provider.set_current_turn_context(None)
|
||||
# Clear the turn context
|
||||
self._tracing_context.set_turn_context(None)
|
||||
|
||||
logger.debug(f"Ended tracing for Turn {turn_number}")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -23,6 +24,13 @@ except ImportError:
|
||||
AIC_FILTER_MODULE = "pipecat.audio.filters.aic_filter"
|
||||
|
||||
|
||||
def _model_manager_ref_count(manager, key: str) -> int:
|
||||
"""Test helper: return reference count for a cache key (reads internal cache)."""
|
||||
with manager._lock:
|
||||
entry = manager._cache.get(key)
|
||||
return entry[1] if entry else 0
|
||||
|
||||
|
||||
class MockProcessor:
|
||||
"""A lightweight mock for AIC ProcessorAsync that mimics real behavior."""
|
||||
|
||||
@@ -99,10 +107,11 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Import AICFilter after confirming aic_sdk is available."""
|
||||
from pipecat.audio.filters.aic_filter import AICFilter
|
||||
from pipecat.audio.filters.aic_filter import AICFilter, AICModelManager
|
||||
from pipecat.frames.frames import FilterEnableFrame
|
||||
|
||||
cls.AICFilter = AICFilter
|
||||
cls.AICModelManager = AICModelManager
|
||||
cls.FilterEnableFrame = FilterEnableFrame
|
||||
|
||||
def setUp(self):
|
||||
@@ -122,13 +131,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def _start_filter_with_mocks(self, filter_instance, sample_rate=16000):
|
||||
"""Start a filter with mocked SDK components."""
|
||||
cache_key = "test-cache-key"
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, cache_key))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
await filter_instance.start(sample_rate)
|
||||
|
||||
@@ -171,37 +180,44 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
filter_instance = self._create_filter_with_mocks(model_id=None, model_path=model_path)
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_manager_cls.acquire = AsyncMock(
|
||||
return_value=(self.mock_model, "path:/tmp/test.aicmodel")
|
||||
)
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
mock_model_cls.from_file.assert_called_once_with(str(model_path))
|
||||
mock_manager_cls.acquire.assert_called_once()
|
||||
call_kw = mock_manager_cls.acquire.call_args[1]
|
||||
self.assertEqual(call_kw["model_path"], model_path)
|
||||
self.assertIsNone(call_kw["model_id"])
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
self.assertEqual(filter_instance._sample_rate, 16000)
|
||||
self.assertEqual(filter_instance._frames_per_block, 160)
|
||||
|
||||
async def test_start_with_model_id_downloads(self):
|
||||
"""Test starting filter with model_id triggers download."""
|
||||
"""Test starting filter with model_id uses manager (download happens in manager)."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(
|
||||
return_value=(self.mock_model, "id:test-model:/custom/cache")
|
||||
)
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
mock_model_cls.download_async.assert_called_once()
|
||||
mock_model_cls.from_file.assert_called_once()
|
||||
mock_manager_cls.acquire.assert_called_once()
|
||||
call_kw = mock_manager_cls.acquire.call_args[1]
|
||||
self.assertEqual(call_kw["model_id"], "test-model")
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
|
||||
async def test_start_creates_processor(self):
|
||||
@@ -209,14 +225,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(
|
||||
f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor
|
||||
) as mock_processor_cls,
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-cache-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
@@ -241,17 +256,21 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(bypass_params[-1][1], 0.0)
|
||||
|
||||
async def test_stop_cleans_up_resources(self):
|
||||
"""Test that stop properly cleans up resources."""
|
||||
"""Test that stop properly cleans up resources and releases model reference."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
await self._start_filter_with_mocks(filter_instance)
|
||||
cache_key = filter_instance._model_cache_key
|
||||
|
||||
await filter_instance.stop()
|
||||
with patch(f"{AIC_FILTER_MODULE}.AICModelManager.release") as mock_release:
|
||||
await filter_instance.stop()
|
||||
|
||||
mock_release.assert_called_once_with(cache_key)
|
||||
self.assertTrue(self.mock_processor.processor_ctx.reset_called)
|
||||
self.assertIsNone(filter_instance._processor)
|
||||
self.assertIsNone(filter_instance._processor_ctx)
|
||||
self.assertIsNone(filter_instance._vad_ctx)
|
||||
self.assertIsNone(filter_instance._model)
|
||||
self.assertIsNone(filter_instance._model_cache_key)
|
||||
self.assertFalse(filter_instance._aic_ready)
|
||||
|
||||
async def test_stop_without_start(self):
|
||||
@@ -261,6 +280,177 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
# Should not raise
|
||||
await filter_instance.stop()
|
||||
|
||||
async def test_model_manager_reference_count(self):
|
||||
"""Test that AICModelManager reference count increments and decrements correctly."""
|
||||
model_path = Path("/tmp/refcount-test.aicmodel")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.from_file.return_value = mock_model
|
||||
|
||||
# Acquire first reference
|
||||
model1, key = await manager.acquire(model_path=model_path)
|
||||
self.assertEqual(model1, mock_model)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
|
||||
# Acquire second reference (same key, cached)
|
||||
model2, key2 = await manager.acquire(model_path=model_path)
|
||||
self.assertIs(model2, model1)
|
||||
self.assertEqual(key2, key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 2)
|
||||
|
||||
# Release one reference
|
||||
manager.release(key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
|
||||
# Release last reference (model evicted from cache)
|
||||
manager.release(key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 0)
|
||||
|
||||
async def test_model_manager_concurrent_load_deduplication(self):
|
||||
"""Test that concurrent acquire calls for the same key share a single load task."""
|
||||
model_path = Path("/tmp/concurrent-load-test.aicmodel")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
load_count = 0
|
||||
|
||||
def from_file_once(path):
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
time.sleep(0.02) # yield so other acquire callers can hit _loading and await same task
|
||||
return mock_model
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.from_file.side_effect = from_file_once
|
||||
|
||||
# Start several acquire calls concurrently before any completes
|
||||
results = await asyncio.gather(
|
||||
manager.acquire(model_path=model_path),
|
||||
manager.acquire(model_path=model_path),
|
||||
manager.acquire(model_path=model_path),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
load_count, 1, "Model.from_file should be called once for concurrent callers"
|
||||
)
|
||||
model1, key1 = results[0]
|
||||
model2, key2 = results[1]
|
||||
model3, key3 = results[2]
|
||||
self.assertIs(model1, mock_model)
|
||||
self.assertIs(model2, mock_model)
|
||||
self.assertIs(model3, mock_model)
|
||||
self.assertEqual(key1, key2)
|
||||
self.assertEqual(key2, key3)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key1), 3)
|
||||
|
||||
# Release all references
|
||||
manager.release(key1)
|
||||
manager.release(key1)
|
||||
manager.release(key1)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key1), 0)
|
||||
|
||||
async def test_load_model_from_file_invalid_args_raises(self):
|
||||
"""Test _load_model_from_file defensive else: raises ValueError."""
|
||||
manager = self.AICModelManager
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
await manager._load_model_from_file(
|
||||
"key",
|
||||
model_path=None,
|
||||
model_id=None,
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("Unexpected", str(ctx.exception))
|
||||
|
||||
async def test_model_manager_acquire_by_model_id_hits_download_path(self):
|
||||
"""Test acquire with model_id runs download path in _load_model_from_file."""
|
||||
model_id = "test-model-id"
|
||||
model_download_dir = Path("/tmp/aic-downloads")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.download_async = AsyncMock(
|
||||
return_value="/tmp/aic-downloads/model.aicmodel"
|
||||
)
|
||||
mock_model_cls.from_file.return_value = mock_model
|
||||
|
||||
model, key = await manager.acquire(
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
|
||||
mock_model_cls.download_async.assert_called_once()
|
||||
mock_model_cls.from_file.assert_called_once_with("/tmp/aic-downloads/model.aicmodel")
|
||||
self.assertIs(model, mock_model)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
manager.release(key)
|
||||
|
||||
def test_get_cache_key_invalid_raises(self):
|
||||
"""Test _get_cache_key raises ValueError for invalid args."""
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
self.AICModelManager._get_cache_key(model_path=None, model_id=None)
|
||||
self.assertIn("model_path", str(ctx.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as ctx2:
|
||||
self.AICModelManager._get_cache_key(
|
||||
model_path=None,
|
||||
model_id="x",
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("model_download_dir", str(ctx2.exception))
|
||||
|
||||
async def test_start_processor_init_failure(self):
|
||||
"""Test start() when ProcessorAsync raises: exception logged, _aic_ready False."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(
|
||||
f"{AIC_FILTER_MODULE}.ProcessorAsync",
|
||||
side_effect=RuntimeError("SDK init failed"),
|
||||
),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertIsNone(filter_instance._processor)
|
||||
self.assertFalse(filter_instance._aic_ready)
|
||||
|
||||
async def test_start_parameter_fixed_error_logged(self):
|
||||
"""Test start() when set_parameter raises ParameterFixedError: logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
self.mock_processor.processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=aic_sdk.ParameterFixedError("fixed")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
|
||||
async def test_process_frame_set_parameter_exception_logged(self):
|
||||
"""Test process_frame when set_parameter raises: exception logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
await self._start_filter_with_mocks(filter_instance)
|
||||
filter_instance._processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=ValueError("param error")
|
||||
)
|
||||
|
||||
await filter_instance.process_frame(self.FilterEnableFrame(enable=True))
|
||||
|
||||
self.assertFalse(filter_instance._bypass)
|
||||
|
||||
async def test_process_frame_enable(self):
|
||||
"""Test processing FilterEnableFrame to enable filtering."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
@@ -24,7 +24,9 @@ from pipecat.frames.frames import (
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
@@ -40,7 +42,11 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_mute import FirstSpeechUserMuteStrategy, FunctionCallUserMuteStrategy
|
||||
from pipecat.turns.user_mute import (
|
||||
FirstSpeechUserMuteStrategy,
|
||||
FunctionCallUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
@@ -386,6 +392,42 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIsNone(strategy) # strategy is None for end/cancel
|
||||
self.assertEqual(message.content, "Hello!")
|
||||
|
||||
async def test_start_frame_before_mute_event(self):
|
||||
"""StartFrame must reach downstream before mute events are broadcast.
|
||||
|
||||
With MuteUntilFirstBotCompleteUserMuteStrategy, the mute logic should
|
||||
not run on control frames (StartFrame, EndFrame, CancelFrame). This
|
||||
ensures StartFrame reaches downstream processors before
|
||||
UserMuteStartedFrame is broadcast.
|
||||
|
||||
The default TurnAnalyzerUserTurnStopStrategy broadcasts a
|
||||
SpeechControlParamsFrame when it processes StartFrame, which gets
|
||||
re-queued to the aggregator. That non-control frame legitimately
|
||||
triggers the mute state change, so UserMuteStartedFrame follows
|
||||
StartFrame — but crucially, after it.
|
||||
"""
|
||||
context = LLMContext()
|
||||
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
params=LLMUserAggregatorParams(
|
||||
user_mute_strategies=[MuteUntilFirstBotCompleteUserMuteStrategy()],
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
# run_test internally sends StartFrame via PipelineRunner. With
|
||||
# ignore_start=False we can verify ordering: StartFrame must arrive
|
||||
# before UserMuteStartedFrame. Before the fix, UserMuteStartedFrame
|
||||
# was broadcast before StartFrame reached downstream processors.
|
||||
(down_frames, _) = await run_test(
|
||||
pipeline,
|
||||
frames_to_send=[],
|
||||
expected_down_frames=[StartFrame, UserMuteStartedFrame],
|
||||
ignore_start=False,
|
||||
)
|
||||
|
||||
|
||||
class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_empty(self):
|
||||
|
||||
@@ -25,7 +25,6 @@ from pipecat.frames.frames import (
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.frame_processor import (
|
||||
INTERRUPTION_COMPLETION_TIMEOUT,
|
||||
FrameDirection,
|
||||
FrameProcessor,
|
||||
)
|
||||
@@ -521,7 +520,7 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
# Complete after the timeout so the warning fires
|
||||
# but the test doesn't hang.
|
||||
async def delayed_complete():
|
||||
await asyncio.sleep(INTERRUPTION_COMPLETION_TIMEOUT + 1.0)
|
||||
await asyncio.sleep(1.0)
|
||||
frame.complete()
|
||||
|
||||
asyncio.create_task(delayed_complete())
|
||||
@@ -532,7 +531,7 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_interruption_task_frame_and_wait(timeout=0.5)
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -223,3 +223,77 @@ async def test_openai_llm_emits_error_frame_on_exception():
|
||||
assert "Error during completion" in pushed_errors[0]["error_msg"]
|
||||
assert "API Error" in pushed_errors[0]["error_msg"]
|
||||
assert isinstance(pushed_errors[0]["exception"], RuntimeError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_llm_async_iterator_closed_on_stream_end():
|
||||
"""Test that the async iterator is explicitly closed after stream consumption.
|
||||
|
||||
This prevents uvloop's broken asyncgen finalizer from firing on Python 3.12+
|
||||
when async generators are garbage-collected without explicit cleanup.
|
||||
See MagicStack/uvloop#699.
|
||||
"""
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
service = OpenAILLMService(model="gpt-4")
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Track if the iterator's aclose was called
|
||||
iterator_aclosed = False
|
||||
stream_closed = False
|
||||
|
||||
class MockAsyncIterator:
|
||||
"""Mock async iterator that tracks aclose() calls."""
|
||||
|
||||
def __init__(self):
|
||||
self.iteration_count = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
self.iteration_count += 1
|
||||
if self.iteration_count > 2:
|
||||
raise StopAsyncIteration()
|
||||
# Return a minimal chunk
|
||||
mock_chunk = AsyncMock()
|
||||
mock_chunk.usage = None
|
||||
mock_chunk.model = None
|
||||
mock_chunk.choices = []
|
||||
return mock_chunk
|
||||
|
||||
async def aclose(self):
|
||||
nonlocal iterator_aclosed
|
||||
iterator_aclosed = True
|
||||
|
||||
class MockAsyncStream:
|
||||
"""Mock stream whose __aiter__ returns a separate iterator object."""
|
||||
|
||||
def __init__(self, iterator):
|
||||
self._iterator = iterator
|
||||
|
||||
def __aiter__(self):
|
||||
return self._iterator
|
||||
|
||||
async def close(self):
|
||||
nonlocal stream_closed
|
||||
stream_closed = True
|
||||
|
||||
mock_iterator = MockAsyncIterator()
|
||||
mock_stream = MockAsyncStream(mock_iterator)
|
||||
|
||||
service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream)
|
||||
service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream)
|
||||
service.start_ttfb_metrics = AsyncMock()
|
||||
service.stop_ttfb_metrics = AsyncMock()
|
||||
service.start_llm_usage_metrics = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
await service._process_context(context)
|
||||
|
||||
# Verify the iterator was explicitly closed (prevents uvloop crash)
|
||||
assert iterator_aclosed, "Async iterator should be explicitly closed"
|
||||
# Verify the stream was also closed (releases HTTP resources)
|
||||
assert stream_closed, "Stream should be closed to release HTTP resources"
|
||||
|
||||
127
tests/test_tracing_context.py
Normal file
127
tests/test_tracing_context.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
HAS_OPENTELEMETRY = True
|
||||
except ImportError:
|
||||
HAS_OPENTELEMETRY = False
|
||||
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
|
||||
|
||||
@unittest.skipUnless(HAS_OPENTELEMETRY, "opentelemetry not installed")
|
||||
class TestTracingContext(unittest.TestCase):
|
||||
"""Tests for TracingContext."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Set up a tracer provider for generating span contexts."""
|
||||
cls._provider = TracerProvider()
|
||||
cls._tracer = cls._provider.get_tracer("test")
|
||||
|
||||
def test_initial_state_is_empty(self):
|
||||
"""Test that a new TracingContext starts with no context set."""
|
||||
ctx = TracingContext()
|
||||
self.assertIsNone(ctx.get_conversation_context())
|
||||
self.assertIsNone(ctx.get_turn_context())
|
||||
self.assertIsNone(ctx.conversation_id)
|
||||
|
||||
def test_set_and_get_conversation_context(self):
|
||||
"""Test setting and retrieving conversation context."""
|
||||
ctx = TracingContext()
|
||||
span = self._tracer.start_span("conv")
|
||||
span_context = span.get_span_context()
|
||||
|
||||
ctx.set_conversation_context(span_context, "conv-123")
|
||||
|
||||
self.assertIsNotNone(ctx.get_conversation_context())
|
||||
self.assertEqual(ctx.conversation_id, "conv-123")
|
||||
span.end()
|
||||
|
||||
def test_clear_conversation_context(self):
|
||||
"""Test clearing conversation context by passing None."""
|
||||
ctx = TracingContext()
|
||||
span = self._tracer.start_span("conv")
|
||||
|
||||
ctx.set_conversation_context(span.get_span_context(), "conv-123")
|
||||
self.assertIsNotNone(ctx.get_conversation_context())
|
||||
|
||||
ctx.set_conversation_context(None)
|
||||
self.assertIsNone(ctx.get_conversation_context())
|
||||
self.assertIsNone(ctx.conversation_id)
|
||||
span.end()
|
||||
|
||||
def test_set_and_get_turn_context(self):
|
||||
"""Test setting and retrieving turn context."""
|
||||
ctx = TracingContext()
|
||||
span = self._tracer.start_span("turn")
|
||||
span_context = span.get_span_context()
|
||||
|
||||
ctx.set_turn_context(span_context)
|
||||
|
||||
self.assertIsNotNone(ctx.get_turn_context())
|
||||
span.end()
|
||||
|
||||
def test_clear_turn_context(self):
|
||||
"""Test clearing turn context by passing None."""
|
||||
ctx = TracingContext()
|
||||
span = self._tracer.start_span("turn")
|
||||
|
||||
ctx.set_turn_context(span.get_span_context())
|
||||
self.assertIsNotNone(ctx.get_turn_context())
|
||||
|
||||
ctx.set_turn_context(None)
|
||||
self.assertIsNone(ctx.get_turn_context())
|
||||
span.end()
|
||||
|
||||
def test_generate_conversation_id(self):
|
||||
"""Test that generated conversation IDs are unique UUIDs."""
|
||||
id1 = TracingContext.generate_conversation_id()
|
||||
id2 = TracingContext.generate_conversation_id()
|
||||
self.assertIsInstance(id1, str)
|
||||
self.assertNotEqual(id1, id2)
|
||||
|
||||
def test_instances_are_isolated(self):
|
||||
"""Test that two TracingContext instances do not share state."""
|
||||
ctx_a = TracingContext()
|
||||
ctx_b = TracingContext()
|
||||
|
||||
span = self._tracer.start_span("turn")
|
||||
|
||||
ctx_a.set_turn_context(span.get_span_context())
|
||||
ctx_a.set_conversation_context(span.get_span_context(), "conv-a")
|
||||
|
||||
# ctx_b should still be empty
|
||||
self.assertIsNone(ctx_b.get_turn_context())
|
||||
self.assertIsNone(ctx_b.get_conversation_context())
|
||||
self.assertIsNone(ctx_b.conversation_id)
|
||||
span.end()
|
||||
|
||||
def test_conversation_and_turn_are_independent(self):
|
||||
"""Test that clearing turn context does not affect conversation context."""
|
||||
ctx = TracingContext()
|
||||
conv_span = self._tracer.start_span("conv")
|
||||
turn_span = self._tracer.start_span("turn")
|
||||
|
||||
ctx.set_conversation_context(conv_span.get_span_context(), "conv-1")
|
||||
ctx.set_turn_context(turn_span.get_span_context())
|
||||
|
||||
# Clear turn but conversation should remain
|
||||
ctx.set_turn_context(None)
|
||||
self.assertIsNone(ctx.get_turn_context())
|
||||
self.assertIsNotNone(ctx.get_conversation_context())
|
||||
self.assertEqual(ctx.conversation_id, "conv-1")
|
||||
|
||||
conv_span.end()
|
||||
turn_span.end()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
505
tests/test_turn_trace_observer.py
Normal file
505
tests/test_turn_trace_observer.py
Normal file
@@ -0,0 +1,505 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
|
||||
|
||||
HAS_OPENTELEMETRY = True
|
||||
except ImportError:
|
||||
HAS_OPENTELEMETRY = False
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
|
||||
from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
if HAS_OPENTELEMETRY:
|
||||
|
||||
class _InMemorySpanExporter(SpanExporter):
|
||||
"""Simple in-memory span exporter for testing."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the exporter."""
|
||||
self._spans = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def export(self, spans):
|
||||
"""Export spans to memory."""
|
||||
with self._lock:
|
||||
self._spans.extend(spans)
|
||||
return SpanExportResult.SUCCESS
|
||||
|
||||
def get_finished_spans(self):
|
||||
"""Return collected spans."""
|
||||
with self._lock:
|
||||
return list(self._spans)
|
||||
|
||||
def clear(self):
|
||||
"""Clear collected spans."""
|
||||
with self._lock:
|
||||
self._spans.clear()
|
||||
|
||||
|
||||
@unittest.skipUnless(HAS_OPENTELEMETRY, "opentelemetry not installed")
|
||||
class TestTurnTraceObserver(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for TurnTraceObserver."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up a fresh provider and exporter for each test.
|
||||
|
||||
We create a dedicated TracerProvider per test and inject its tracer
|
||||
directly into the observer, avoiding the global provider singleton.
|
||||
"""
|
||||
self._exporter = _InMemorySpanExporter()
|
||||
self._provider = TracerProvider()
|
||||
self._provider.add_span_processor(SimpleSpanProcessor(self._exporter))
|
||||
self._tracer = self._provider.get_tracer("pipecat.turn")
|
||||
|
||||
def tearDown(self):
|
||||
"""Shut down the provider to flush spans."""
|
||||
self._provider.shutdown()
|
||||
|
||||
def _create_observers(self, conversation_id=None, tracing_context=None):
|
||||
"""Create a standard set of turn/trace observers.
|
||||
|
||||
Args:
|
||||
conversation_id: Optional conversation ID.
|
||||
tracing_context: Optional TracingContext instance.
|
||||
|
||||
Returns:
|
||||
Tuple of (turn_tracker, latency_tracker, trace_observer, tracing_context).
|
||||
"""
|
||||
tracing_context = tracing_context or TracingContext()
|
||||
turn_tracker = TurnTrackingObserver(turn_end_timeout_secs=0.2)
|
||||
latency_tracker = UserBotLatencyObserver()
|
||||
trace_observer = TurnTraceObserver(
|
||||
turn_tracker,
|
||||
latency_tracker=latency_tracker,
|
||||
conversation_id=conversation_id,
|
||||
tracing_context=tracing_context,
|
||||
)
|
||||
# Inject the test tracer so spans go to our in-memory exporter
|
||||
trace_observer._tracer = self._tracer
|
||||
return turn_tracker, latency_tracker, trace_observer, tracing_context
|
||||
|
||||
def _all_observers(self, trace_observer):
|
||||
"""Return the list of observers needed for run_test."""
|
||||
return [trace_observer._turn_tracker, trace_observer._latency_tracker, trace_observer]
|
||||
|
||||
def _get_spans_by_name(self, name):
|
||||
"""Return finished spans with the given name."""
|
||||
return [s for s in self._exporter.get_finished_spans() if s.name == name]
|
||||
|
||||
async def test_conversation_span_created_on_start_frame(self):
|
||||
"""Test that a conversation span is created when StartFrame is observed."""
|
||||
_, _, trace_observer, _ = self._create_observers(conversation_id="test-conv")
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# End conversation to flush the conversation span (normally done by PipelineTask._cleanup)
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
conv_spans = self._get_spans_by_name("conversation")
|
||||
self.assertEqual(len(conv_spans), 1)
|
||||
self.assertEqual(conv_spans[0].attributes["conversation.id"], "test-conv")
|
||||
self.assertEqual(conv_spans[0].attributes["conversation.type"], "voice")
|
||||
|
||||
async def test_turn_spans_created_for_each_turn(self):
|
||||
"""Test that a turn span is created for each conversation turn."""
|
||||
_, _, trace_observer, _ = self._create_observers()
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
# Turn 1
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.05),
|
||||
# Turn 2
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
turn_spans = self._get_spans_by_name("turn")
|
||||
self.assertEqual(len(turn_spans), 2)
|
||||
turn_numbers = {s.attributes["turn.number"] for s in turn_spans}
|
||||
self.assertEqual(turn_numbers, {1, 2})
|
||||
|
||||
async def test_turn_spans_are_children_of_conversation(self):
|
||||
"""Test that turn spans are parented under the conversation span."""
|
||||
_, _, trace_observer, _ = self._create_observers()
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# End conversation to flush the conversation span
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
conv_spans = self._get_spans_by_name("conversation")
|
||||
turn_spans = self._get_spans_by_name("turn")
|
||||
self.assertEqual(len(conv_spans), 1)
|
||||
self.assertEqual(len(turn_spans), 1)
|
||||
|
||||
# Turn span's parent should be the conversation span
|
||||
conv_span_id = conv_spans[0].context.span_id
|
||||
turn_parent_id = turn_spans[0].parent.span_id
|
||||
self.assertEqual(turn_parent_id, conv_span_id)
|
||||
|
||||
async def test_interrupted_turn_marked(self):
|
||||
"""Test that an interrupted turn span has was_interrupted=True."""
|
||||
_, _, trace_observer, _ = self._create_observers()
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
# User interrupts
|
||||
UserStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# End conversation to flush remaining spans
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
turn_spans = self._get_spans_by_name("turn")
|
||||
self.assertGreaterEqual(len(turn_spans), 1)
|
||||
# First turn should be interrupted
|
||||
interrupted_turns = [s for s in turn_spans if s.attributes.get("turn.was_interrupted")]
|
||||
self.assertGreaterEqual(len(interrupted_turns), 1)
|
||||
|
||||
async def test_tracing_context_updated_during_turn(self):
|
||||
"""Test that TracingContext is populated during a turn and cleared after."""
|
||||
tracing_ctx = TracingContext()
|
||||
_, _, trace_observer, _ = self._create_observers(tracing_context=tracing_ctx)
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# After the turn ends, turn context should be cleared
|
||||
self.assertIsNone(tracing_ctx.get_turn_context())
|
||||
|
||||
async def test_tracing_context_cleared_after_conversation_end(self):
|
||||
"""Test that TracingContext is cleared when conversation tracing ends."""
|
||||
tracing_ctx = TracingContext()
|
||||
_, _, trace_observer, _ = self._create_observers(tracing_context=tracing_ctx)
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# Manually end conversation tracing (as PipelineTask._cleanup does)
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
self.assertIsNone(tracing_ctx.get_conversation_context())
|
||||
self.assertIsNone(tracing_ctx.get_turn_context())
|
||||
self.assertIsNone(tracing_ctx.conversation_id)
|
||||
|
||||
async def test_additional_span_attributes(self):
|
||||
"""Test that additional span attributes are added to the conversation span."""
|
||||
extra_attrs = {"deployment.id": "abc-123", "customer.tier": "premium"}
|
||||
tracing_ctx = TracingContext()
|
||||
turn_tracker = TurnTrackingObserver(turn_end_timeout_secs=0.2)
|
||||
latency_tracker = UserBotLatencyObserver()
|
||||
trace_observer = TurnTraceObserver(
|
||||
turn_tracker,
|
||||
latency_tracker=latency_tracker,
|
||||
additional_span_attributes=extra_attrs,
|
||||
tracing_context=tracing_ctx,
|
||||
)
|
||||
trace_observer._tracer = self._tracer
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=[turn_tracker, latency_tracker, trace_observer],
|
||||
)
|
||||
|
||||
# End conversation to flush the conversation span
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
conv_spans = self._get_spans_by_name("conversation")
|
||||
self.assertEqual(len(conv_spans), 1)
|
||||
self.assertEqual(conv_spans[0].attributes["deployment.id"], "abc-123")
|
||||
self.assertEqual(conv_spans[0].attributes["customer.tier"], "premium")
|
||||
|
||||
async def test_concurrent_pipelines_are_isolated(self):
|
||||
"""Test that two pipelines with separate TracingContexts don't interfere."""
|
||||
tracing_ctx_a = TracingContext()
|
||||
tracing_ctx_b = TracingContext()
|
||||
|
||||
_, _, trace_observer_a, _ = self._create_observers(
|
||||
conversation_id="conv-a", tracing_context=tracing_ctx_a
|
||||
)
|
||||
_, _, trace_observer_b, _ = self._create_observers(
|
||||
conversation_id="conv-b", tracing_context=tracing_ctx_b
|
||||
)
|
||||
|
||||
processor_a = IdentityFilter()
|
||||
processor_b = IdentityFilter()
|
||||
|
||||
frames = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
# Run both pipelines concurrently
|
||||
await asyncio.gather(
|
||||
run_test(
|
||||
processor_a,
|
||||
frames_to_send=frames,
|
||||
expected_down_frames=expected,
|
||||
observers=self._all_observers(trace_observer_a),
|
||||
),
|
||||
run_test(
|
||||
processor_b,
|
||||
frames_to_send=frames,
|
||||
expected_down_frames=expected,
|
||||
observers=self._all_observers(trace_observer_b),
|
||||
),
|
||||
)
|
||||
|
||||
# End both conversations to flush spans
|
||||
trace_observer_a.end_conversation_tracing()
|
||||
trace_observer_b.end_conversation_tracing()
|
||||
|
||||
# Each TracingContext should have its own conversation ID
|
||||
conv_spans = self._get_spans_by_name("conversation")
|
||||
conv_ids = {s.attributes["conversation.id"] for s in conv_spans}
|
||||
self.assertEqual(conv_ids, {"conv-a", "conv-b"})
|
||||
|
||||
# Turn spans should be children of their own conversation span, not cross-linked
|
||||
turn_spans = self._get_spans_by_name("turn")
|
||||
conv_span_map = {s.context.span_id: s.attributes["conversation.id"] for s in conv_spans}
|
||||
for turn_span in turn_spans:
|
||||
parent_id = turn_span.parent.span_id
|
||||
turn_conv_id = turn_span.attributes["conversation.id"]
|
||||
parent_conv_id = conv_span_map[parent_id]
|
||||
self.assertEqual(
|
||||
turn_conv_id,
|
||||
parent_conv_id,
|
||||
f"Turn span for {turn_conv_id} parented under {parent_conv_id}",
|
||||
)
|
||||
|
||||
async def test_end_conversation_closes_active_turn(self):
|
||||
"""Test that end_conversation_tracing closes any active turn span."""
|
||||
_, _, trace_observer, _ = self._create_observers()
|
||||
|
||||
# Manually start conversation and a turn
|
||||
trace_observer.start_conversation_tracing("conv-end-test")
|
||||
await trace_observer._handle_turn_started(1)
|
||||
|
||||
self.assertIsNotNone(trace_observer._current_span)
|
||||
self.assertIsNotNone(trace_observer._conversation_span)
|
||||
|
||||
# End conversation — should close both turn and conversation
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
self.assertIsNone(trace_observer._current_span)
|
||||
self.assertIsNone(trace_observer._conversation_span)
|
||||
|
||||
# Check span attributes
|
||||
turn_spans = self._get_spans_by_name("turn")
|
||||
self.assertEqual(len(turn_spans), 1)
|
||||
self.assertTrue(turn_spans[0].attributes["turn.was_interrupted"])
|
||||
self.assertTrue(turn_spans[0].attributes["turn.ended_by_conversation_end"])
|
||||
|
||||
async def test_conversation_id_auto_generated(self):
|
||||
"""Test that a conversation ID is auto-generated when none is provided."""
|
||||
_, _, trace_observer, _ = self._create_observers(conversation_id=None)
|
||||
processor = IdentityFilter()
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=self._all_observers(trace_observer),
|
||||
)
|
||||
|
||||
# End conversation to flush the conversation span
|
||||
trace_observer.end_conversation_tracing()
|
||||
|
||||
conv_spans = self._get_spans_by_name("conversation")
|
||||
self.assertEqual(len(conv_spans), 1)
|
||||
# Should have an auto-generated UUID as conversation.id
|
||||
conv_id = conv_spans[0].attributes["conversation.id"]
|
||||
self.assertIsNotNone(conv_id)
|
||||
self.assertGreater(len(conv_id), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -6,12 +6,14 @@
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
UserSpeakingFrame,
|
||||
UserIdleTimeoutUpdateFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user_idle_controller import UserIdleController
|
||||
@@ -25,8 +27,8 @@ class TestUserIdleController(unittest.IsolatedAsyncioTestCase):
|
||||
self.task_manager = TaskManager()
|
||||
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
async def test_basic_idle_detection(self):
|
||||
"""Test that idle event is triggered after timeout when no activity."""
|
||||
async def test_idle_after_bot_stops_speaking(self):
|
||||
"""Test that idle event fires after BotStoppedSpeakingFrame + timeout."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
@@ -37,18 +39,16 @@ class TestUserIdleController(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start conversation
|
||||
await controller.process_frame(UserStartedSpeakingFrame())
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
|
||||
# Wait for idle timeout
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertTrue(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_user_speaking_resets_idle_timer(self):
|
||||
"""Test that continuous UserSpeakingFrame frames reset the idle timer."""
|
||||
async def test_user_speaking_cancels_timer(self):
|
||||
"""Test that UserStartedSpeakingFrame cancels the idle timer."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
@@ -59,20 +59,18 @@ class TestUserIdleController(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start conversation
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.3)
|
||||
await controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
# Send UserSpeakingFrame continuously to reset timer
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.5) # 50% of timeout period
|
||||
await controller.process_frame(UserSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_bot_speaking_resets_idle_timer(self):
|
||||
"""Test that BotSpeakingFrame frames reset the idle timer."""
|
||||
async def test_bot_speaking_cancels_timer(self):
|
||||
"""Test that BotStartedSpeakingFrame cancels the idle timer."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
@@ -83,102 +81,61 @@ class TestUserIdleController(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start conversation
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.3)
|
||||
await controller.process_frame(BotStartedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_no_idle_before_bot_speaks(self):
|
||||
"""Test that idle does not fire if no BotStoppedSpeakingFrame is received."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Wait without any frames
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_interruption_no_false_trigger(self):
|
||||
"""Test that BotStoppedSpeakingFrame during a user turn does not start the timer."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# User starts speaking (interruption)
|
||||
await controller.process_frame(UserStartedSpeakingFrame())
|
||||
# Bot stops speaking due to interruption
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
|
||||
# Bot speaking should reset timer
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.6) # 60% of timeout
|
||||
await controller.process_frame(BotSpeakingFrame())
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_function_call_prevents_idle(self):
|
||||
"""Test that function calls in progress prevent idle event."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start conversation
|
||||
await controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
# Start function call
|
||||
await controller.process_frame(FunctionCallsStartedFrame(function_calls=[]))
|
||||
|
||||
# Wait longer than idle timeout
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
# Should not trigger idle because function call is in progress
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
# Complete function call
|
||||
await controller.process_frame(
|
||||
FunctionCallResultFrame(
|
||||
function_name="test",
|
||||
tool_call_id="123",
|
||||
arguments={},
|
||||
result=None,
|
||||
run_llm=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Now idle should trigger
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
self.assertTrue(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_no_idle_before_conversation_starts(self):
|
||||
"""Test that idle monitoring doesn't start before first conversation activity."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Wait without starting conversation
|
||||
# Wait - timer should NOT have started because user turn is in progress
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_idle_starts_with_bot_speaking(self):
|
||||
"""Test that monitoring starts with BotSpeakingFrame, not just user speech."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start conversation with bot speaking
|
||||
await controller.process_frame(BotSpeakingFrame())
|
||||
|
||||
# Wait for idle timeout
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertTrue(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_multiple_idle_events(self):
|
||||
"""Test that idle event can trigger multiple times."""
|
||||
async def test_idle_cycle(self):
|
||||
"""Test that idle fires, then can fire again after another bot speaking cycle."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
@@ -189,29 +146,175 @@ class TestUserIdleController(unittest.IsolatedAsyncioTestCase):
|
||||
nonlocal idle_count
|
||||
idle_count += 1
|
||||
|
||||
# Start conversation
|
||||
await controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
# First idle
|
||||
# First cycle: bot stops → idle fires
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
first_count = idle_count
|
||||
self.assertGreaterEqual(first_count, 1)
|
||||
self.assertEqual(idle_count, 1)
|
||||
|
||||
# Second idle
|
||||
# Second cycle: bot starts → bot stops → idle fires again
|
||||
await controller.process_frame(BotStartedSpeakingFrame())
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
second_count = idle_count
|
||||
self.assertGreater(second_count, first_count)
|
||||
self.assertEqual(idle_count, 2)
|
||||
|
||||
# User activity resets timer
|
||||
await controller.process_frame(UserSpeakingFrame())
|
||||
await controller.cleanup()
|
||||
|
||||
# Give a moment for the timer to reset
|
||||
await asyncio.sleep(0.1)
|
||||
async def test_cleanup_cancels_timer(self):
|
||||
"""Test that cleanup cancels a pending idle timer."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.3)
|
||||
await controller.cleanup()
|
||||
|
||||
# Third idle
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
third_count = idle_count
|
||||
self.assertGreater(third_count, second_count)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
async def test_function_call_cancels_timer(self):
|
||||
"""Test normal ordering: BotStopped starts timer, FunctionCallsStarted cancels it."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Bot finishes speaking, timer starts
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
# Function call starts shortly after, cancels the timer
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.3)
|
||||
await controller.process_frame(
|
||||
FunctionCallsStartedFrame(function_calls=[unittest.mock.Mock()])
|
||||
)
|
||||
|
||||
# Wait longer than timeout — should not fire
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_function_call_suppresses_timer(self):
|
||||
"""Test race condition: FunctionCallsStarted arrives before BotStopped.
|
||||
|
||||
A race condition can cause FunctionCallsStarted to arrive before
|
||||
BotStoppedSpeaking. The counter guard prevents the timer from starting
|
||||
while a function call is in progress.
|
||||
"""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# LLM emits function call and "let me check" concurrently
|
||||
await controller.process_frame(
|
||||
FunctionCallsStartedFrame(function_calls=[unittest.mock.Mock()])
|
||||
)
|
||||
await controller.process_frame(BotStartedSpeakingFrame())
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
|
||||
# Wait longer than timeout — should not fire (function call in progress)
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
# Function call completes, bot speaks result
|
||||
await controller.process_frame(
|
||||
FunctionCallResultFrame(
|
||||
function_name="test", tool_call_id="123", arguments={}, result="ok"
|
||||
)
|
||||
)
|
||||
await controller.process_frame(BotStartedSpeakingFrame())
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
|
||||
# Now the timer should start and fire
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
self.assertTrue(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_disabled_by_default(self):
|
||||
"""Test that timeout=0 means idle detection is disabled."""
|
||||
controller = UserIdleController()
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_enable_via_frame(self):
|
||||
"""Test enabling idle detection at runtime via UserIdleTimeoutUpdateFrame."""
|
||||
controller = UserIdleController()
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Initially disabled — no idle fires
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
# Enable idle detection
|
||||
await controller.process_frame(UserIdleTimeoutUpdateFrame(timeout=USER_IDLE_TIMEOUT))
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertTrue(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
async def test_disable_via_frame(self):
|
||||
"""Test disabling idle detection at runtime via UserIdleTimeoutUpdateFrame."""
|
||||
controller = UserIdleController(user_idle_timeout=USER_IDLE_TIMEOUT)
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
idle_triggered = False
|
||||
|
||||
@controller.event_handler("on_user_turn_idle")
|
||||
async def on_user_turn_idle(controller):
|
||||
nonlocal idle_triggered
|
||||
idle_triggered = True
|
||||
|
||||
# Start the timer
|
||||
await controller.process_frame(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT * 0.3)
|
||||
|
||||
# Disable — should cancel running timer
|
||||
await controller.process_frame(UserIdleTimeoutUpdateFrame(timeout=0))
|
||||
|
||||
await asyncio.sleep(USER_IDLE_TIMEOUT + 0.1)
|
||||
|
||||
self.assertFalse(idle_triggered)
|
||||
|
||||
await controller.cleanup()
|
||||
|
||||
|
||||
@@ -452,72 +452,6 @@ class TestSpeechTimeoutUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_finalized_transcript_does_not_trigger_early_with_slow_stt(self):
|
||||
"""Test that a finalized transcript arriving after user_speech_timeout
|
||||
but before the full timeout does not trigger immediately.
|
||||
|
||||
This reproduces a race condition where:
|
||||
- STT has high latency (effective_stt_wait > user_speech_timeout)
|
||||
- User pauses briefly, VAD fires stop
|
||||
- The full timeout = max(effective_stt_wait, user_speech_timeout)
|
||||
- The finalized transcript arrives after user_speech_timeout from VAD stop
|
||||
but before the full timeout
|
||||
- The user resumes speaking before the full timeout
|
||||
|
||||
Previously, the early trigger path would fire because
|
||||
time.time() - vad_stopped_time >= user_speech_timeout, even though the
|
||||
user was about to resume speaking.
|
||||
"""
|
||||
user_speech_timeout = 0.1
|
||||
strategy = SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=user_speech_timeout)
|
||||
await strategy.setup(self.task_manager)
|
||||
|
||||
# Set high STT P99 latency so effective_stt_wait > user_speech_timeout
|
||||
stt_timeout = 0.5
|
||||
stop_secs = 0.1
|
||||
await strategy.process_frame(
|
||||
STTMetadataFrame(service_name="test", ttfs_p99_latency=stt_timeout)
|
||||
)
|
||||
# effective_stt_wait = max(0, 0.5 - 0.1) = 0.4
|
||||
# timeout = max(0.4, 0.1) = 0.4
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
# S - user starts speaking
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
|
||||
# E - user pauses briefly
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
|
||||
# Wait for user_speech_timeout to elapse but NOT the full timeout
|
||||
await asyncio.sleep(user_speech_timeout + 0.05) # 0.15s elapsed
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Finalized transcript arrives (simulating slow STT).
|
||||
# At this point, elapsed from VAD stop (~0.15s) > user_speech_timeout (0.1s).
|
||||
# The old code would trigger immediately here.
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp="", finalized=True)
|
||||
)
|
||||
|
||||
# Should NOT trigger — the full timeout (0.4s) hasn't elapsed yet,
|
||||
# giving the user time to resume speaking
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# User resumes speaking — this cancels the timeout
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
|
||||
# Wait well past the original timeout
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Should still not have triggered — user resumed speaking
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
async def test_sie_delay_it(self):
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user