Compare commits
155 Commits
mb/remove-
...
v0.0.104
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5940731dd0 | ||
|
|
62260454a2 | ||
|
|
d1ad7a9580 | ||
|
|
252f17e1ca | ||
|
|
c79a739c85 | ||
|
|
038f6a77d1 | ||
|
|
5952ea711c | ||
|
|
aad1211a57 | ||
|
|
7dbb130666 | ||
|
|
c6c2c5ba05 | ||
|
|
141b0ee014 | ||
|
|
303616599f | ||
|
|
088eb9b01c | ||
|
|
32773b42d6 | ||
|
|
c039e08741 | ||
|
|
b449515410 | ||
|
|
aae9136df9 | ||
|
|
fdeddd7c95 | ||
|
|
11783520c0 | ||
|
|
49c73bb0a3 | ||
|
|
f07e55a4ed | ||
|
|
daf14f5065 | ||
|
|
ebb794995b | ||
|
|
5c2ca0ce64 | ||
|
|
6729f4366a | ||
|
|
7648b62e6e | ||
|
|
7afd7068b5 | ||
|
|
07fdd610ca | ||
|
|
a4796a2373 | ||
|
|
44466cfa07 | ||
|
|
741ff14d3a | ||
|
|
4a61d5bfad | ||
|
|
d0ecb3c7a8 | ||
|
|
8f66272de7 | ||
|
|
ff5b985009 | ||
|
|
a738a4d82b | ||
|
|
ddba1b84a9 | ||
|
|
18155b6a63 | ||
|
|
ac69b3441e | ||
|
|
98bd530574 | ||
|
|
b1e55fd6c2 | ||
|
|
dbdb54ce0f | ||
|
|
c1743dcffd | ||
|
|
389d0c3fb6 | ||
|
|
a88eae7849 | ||
|
|
0cfd953a90 | ||
|
|
bbbfdfd321 | ||
|
|
193f93c2ce | ||
|
|
75669b12a2 | ||
|
|
68e8732e72 | ||
|
|
de87894778 | ||
|
|
0836066898 | ||
|
|
58aa8e1ba5 | ||
|
|
670e5000d2 | ||
|
|
e6b9c5c4dc | ||
|
|
c54232bdb4 | ||
|
|
5a6a93e277 | ||
|
|
f386722ef9 | ||
|
|
7c07e090a4 | ||
|
|
07ba255073 | ||
|
|
eb7a4b7aee | ||
|
|
ad74d19c6b | ||
|
|
5e8d722bf2 | ||
|
|
a7f6db8436 | ||
|
|
442ea6a97e | ||
|
|
018ead8551 | ||
|
|
5e99aeedf5 | ||
|
|
c579749d8a | ||
|
|
094de42f0c | ||
|
|
1242f1c10e | ||
|
|
55a641e258 | ||
|
|
91c46ffbf4 | ||
|
|
024c62946f | ||
|
|
9b969736f6 | ||
|
|
6fc718947d | ||
|
|
cb7e612738 | ||
|
|
36b9c05730 | ||
|
|
6968d83ccb | ||
|
|
42f91a9056 | ||
|
|
5de495cc98 | ||
|
|
d1cbc81108 | ||
|
|
66fca7e382 | ||
|
|
07ae4b8d38 | ||
|
|
21a409e447 | ||
|
|
903dc6c1a9 | ||
|
|
dee94b3cb8 | ||
|
|
ece4343839 | ||
|
|
94a59de4e1 | ||
|
|
f37fd39cdb | ||
|
|
9d4955054c | ||
|
|
6464230627 | ||
|
|
950a8628dc | ||
|
|
17205c1647 | ||
|
|
2a776d0c1e | ||
|
|
d7ce1eedd9 | ||
|
|
ef00f27d53 | ||
|
|
56f2564ed1 | ||
|
|
000d38e253 | ||
|
|
36edef489e | ||
|
|
d077a810ae | ||
|
|
0839e3813f | ||
|
|
69414e8a5a | ||
|
|
dfd0a515f3 | ||
|
|
ed7f0a2c08 | ||
|
|
08d93ce9b6 | ||
|
|
f11d4b6944 | ||
|
|
51a3310e78 | ||
|
|
6f33aff0c6 | ||
|
|
45532a9478 | ||
|
|
4eb993c980 | ||
|
|
83e29eb478 | ||
|
|
6ba9f780b0 | ||
|
|
aa7e9a17d5 | ||
|
|
acff172bf2 | ||
|
|
9747e8da4a | ||
|
|
8fc63352d9 | ||
|
|
6ebfea4746 | ||
|
|
f74af9b9c7 | ||
|
|
82c249608f | ||
|
|
98e737b4e9 | ||
|
|
ec9ddb3199 | ||
|
|
712305c5b1 | ||
|
|
be8ea818c8 | ||
|
|
50710e9c3f | ||
|
|
a489bfaf00 | ||
|
|
945a523eed | ||
|
|
790c434a08 | ||
|
|
db40a354be | ||
|
|
aa6d3b38b3 | ||
|
|
41d6470e4a | ||
|
|
601822e3e5 | ||
|
|
3a32d91c66 | ||
|
|
35b3803ebc | ||
|
|
3b427a47b6 | ||
|
|
d701c3427c | ||
|
|
1f45e80f9d | ||
|
|
bc6f8e51de | ||
|
|
deba2515f9 | ||
|
|
127b52bad5 | ||
|
|
0697f72dae | ||
|
|
c259a6a73b | ||
|
|
3e04f5d05f | ||
|
|
cd07937c5d | ||
|
|
72934bd8ae | ||
|
|
2a6a993869 | ||
|
|
bbaa79fef0 | ||
|
|
fff9db0d8f | ||
|
|
b390dc369c | ||
|
|
a18aa738e0 | ||
|
|
9476b5d184 | ||
|
|
f49658de15 | ||
|
|
d38b1d97d4 | ||
|
|
0b4568843b | ||
|
|
35aba4128c | ||
|
|
5ea2d47d39 |
147
.github/workflows/update-docs.yml
vendored
Normal file
147
.github/workflows/update-docs.yml
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
name: Update Documentation on PR Merge
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/pipecat/services/**"
|
||||
- "src/pipecat/transports/**"
|
||||
- "src/pipecat/serializers/**"
|
||||
- "src/pipecat/processors/**"
|
||||
- "src/pipecat/audio/**"
|
||||
- "src/pipecat/turns/**"
|
||||
- "src/pipecat/observers/**"
|
||||
- "src/pipecat/pipeline/**"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pr_number:
|
||||
description: "PR number to generate docs for"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-docs:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Checkout pipecat
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Checkout docs
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: pipecat-ai/docs
|
||||
token: ${{ secrets.DOCS_SYNC_TOKEN }}
|
||||
path: _docs
|
||||
|
||||
- name: Resolve PR number
|
||||
id: pr
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
echo "number=${{ inputs.pr_number }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "number=${{ github.event.pull_request.number }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Update documentation
|
||||
uses: anthropics/claude-code-action@v1
|
||||
env:
|
||||
DOCS_SYNC_TOKEN: ${{ secrets.DOCS_SYNC_TOKEN }}
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
prompt: |
|
||||
You are updating documentation for the pipecat-ai/docs repository based on
|
||||
changes merged in PR #${{ steps.pr.outputs.number }} of pipecat-ai/pipecat.
|
||||
|
||||
## Setup
|
||||
|
||||
1. Read the skill instructions at `.claude/skills/update-docs/SKILL.md`
|
||||
2. Read the source-to-doc mapping at `.claude/skills/update-docs/SOURCE_DOC_MAPPING.md`
|
||||
3. The docs repository is checked out at `./_docs/`
|
||||
|
||||
## Get the diff
|
||||
|
||||
Run `gh pr diff ${{ steps.pr.outputs.number }}` to see what changed in the PR.
|
||||
Also run `gh pr diff ${{ steps.pr.outputs.number }} --name-only` to get the list of changed files.
|
||||
Filter to source files matching the directories listed in SKILL.md Step 3.
|
||||
|
||||
If no relevant source files were changed, exit with "No documentation changes needed."
|
||||
|
||||
## Follow the skill instructions
|
||||
|
||||
Apply the SKILL.md workflow (Steps 3-9) with these adaptations for automation:
|
||||
|
||||
### Docs path
|
||||
Use `./_docs/` — it's already checked out. Do not ask for a path.
|
||||
|
||||
### Branch management
|
||||
- Branch name: `docs/pr-${{ steps.pr.outputs.number }}`
|
||||
- Work inside `./_docs/` for all doc edits and git operations
|
||||
- Check if the branch already exists on the remote:
|
||||
```bash
|
||||
cd _docs && git fetch origin docs/pr-${{ steps.pr.outputs.number }} 2>/dev/null
|
||||
```
|
||||
- If it exists: check it out (supports workflow re-runs)
|
||||
- If not: create it from main
|
||||
|
||||
### Git config
|
||||
Before committing in `_docs`, set:
|
||||
```bash
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
```
|
||||
|
||||
### No interactive questions
|
||||
Do not ask questions. If you encounter gaps (unmapped files, missing sections,
|
||||
ambiguous changes), note them in the PR body under "## Gaps identified".
|
||||
|
||||
### Creating the docs PR
|
||||
After committing all changes in `_docs`, push and create a PR:
|
||||
```bash
|
||||
cd _docs
|
||||
git push -u origin docs/pr-${{ steps.pr.outputs.number }}
|
||||
GH_TOKEN=$DOCS_SYNC_TOKEN gh pr create \
|
||||
--repo pipecat-ai/docs \
|
||||
--label auto-docs \
|
||||
--title "docs: update for pipecat PR #${{ steps.pr.outputs.number }}" \
|
||||
--body "$(cat <<'BODY'
|
||||
Automated documentation update for [pipecat PR #${{ steps.pr.outputs.number }}](https://github.com/pipecat-ai/pipecat/pull/${{ steps.pr.outputs.number }}).
|
||||
|
||||
## Changes
|
||||
<summarize each doc page updated and what changed>
|
||||
|
||||
## Gaps identified
|
||||
<any unmapped files, missing doc pages, or missing sections — or "None">
|
||||
BODY
|
||||
)"
|
||||
```
|
||||
|
||||
### Re-run handling
|
||||
If `gh pr create` fails because a PR from that branch already exists,
|
||||
push the updated commits and use `gh pr edit` to update the body instead.
|
||||
|
||||
### No-op
|
||||
If after analyzing the diff you determine no documentation changes are needed
|
||||
(e.g., only skip-listed files changed, or changes don't affect public API docs),
|
||||
exit cleanly without creating a branch or PR. Output "No documentation changes needed."
|
||||
|
||||
## Important rules
|
||||
- Only modify files inside `./_docs/` — never modify pipecat source code
|
||||
- Follow the conservative editing rules from SKILL.md Step 6
|
||||
- Read each doc page fully before editing (SKILL.md Guidelines)
|
||||
- Use `GH_TOKEN=$DOCS_SYNC_TOKEN` for all `gh` commands targeting pipecat-ai/docs
|
||||
claude_args: |
|
||||
--model claude-sonnet-4-5-20250929
|
||||
--max-turns 30
|
||||
--allowedTools "Read,Write,Edit,Glob,Grep,Bash"
|
||||
383
CHANGELOG.md
383
CHANGELOG.md
@@ -7,6 +7,389 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
<!-- towncrier release notes start -->
|
||||
|
||||
## [0.0.104] - 2026-03-02
|
||||
|
||||
### Added
|
||||
|
||||
- Added `TextAggregationMetricsData` metric measuring the time from the first
|
||||
LLM token to the first complete sentence, representing the latency cost of
|
||||
sentence aggregation in the TTS pipeline.
|
||||
(PR [#3696](https://github.com/pipecat-ai/pipecat/pull/3696))
|
||||
|
||||
- Added support for using strongly-typed objects instead of dicts for updating
|
||||
service settings at runtime.
|
||||
|
||||
Instead of, say:
|
||||
|
||||
```python
|
||||
await task.queue_frame(
|
||||
STTUpdateSettingsFrame(settings={"language": Language.ES})
|
||||
)
|
||||
```
|
||||
|
||||
you'd do:
|
||||
|
||||
```python
|
||||
await task.queue_frame(
|
||||
STTUpdateSettingsFrame(delta=DeepgramSTTSettings(language=Language.ES))
|
||||
)
|
||||
```
|
||||
|
||||
Each service now vends strongly-typed classes like `DeepgramSTTSettings`
|
||||
representing the service's runtime-updatable settings.
|
||||
(PR [#3714](https://github.com/pipecat-ai/pipecat/pull/3714))
|
||||
|
||||
- Added support for specifying private endpoints for Azure Speech-to-Text,
|
||||
enabling use in private networks behind firewalls.
|
||||
(PR [#3764](https://github.com/pipecat-ai/pipecat/pull/3764))
|
||||
|
||||
- Added `LemonSliceTransport` and `LemonSliceApi` to support adding real-time
|
||||
LemonSlice Avatars to any Daily room.
|
||||
(PR [#3791](https://github.com/pipecat-ai/pipecat/pull/3791))
|
||||
|
||||
- Added `output_medium` parameter to `AgentInputParams` and
|
||||
`OneShotInputParams` in Ultravox service to control initial output medium
|
||||
(text or voice) at call creation time.
|
||||
(PR [#3806](https://github.com/pipecat-ai/pipecat/pull/3806))
|
||||
|
||||
- Added `TurnMetricsData` as a generic metrics class for turn detection, with
|
||||
e2e processing time measurement. `KrispVivaTurn` now emits `TurnMetricsData`
|
||||
with `e2e_processing_time_ms` tracking the interval from VAD
|
||||
speech-to-silence transition to turn completion.
|
||||
(PR [#3809](https://github.com/pipecat-ai/pipecat/pull/3809))
|
||||
|
||||
- Added `on_audio_context_interrupted()` and `on_audio_context_completed()`
|
||||
callbacks to `AudioContextTTSService`. Subclasses can override these to
|
||||
perform provider-specific cleanup instead of overriding
|
||||
`_handle_interruption()`.
|
||||
(PR [#3814](https://github.com/pipecat-ai/pipecat/pull/3814))
|
||||
|
||||
- Added `on_summary_applied` event to `LLMContextSummarizer` for observability,
|
||||
providing message counts before and after context summarization.
|
||||
(PR [#3855](https://github.com/pipecat-ai/pipecat/pull/3855))
|
||||
|
||||
- Added `summary_message_template` to `LLMContextSummarizationConfig` for
|
||||
customizing how summaries are formatted when injected into context (e.g.,
|
||||
wrapping in XML tags).
|
||||
(PR [#3855](https://github.com/pipecat-ai/pipecat/pull/3855))
|
||||
|
||||
- Added `summarization_timeout` to `LLMContextSummarizationConfig` (default
|
||||
120s) to prevent hung LLM calls from permanently blocking future
|
||||
summarizations.
|
||||
(PR [#3855](https://github.com/pipecat-ai/pipecat/pull/3855))
|
||||
|
||||
- Added optional `llm` field to `LLMContextSummarizationConfig` for routing
|
||||
summarization to a dedicated LLM service (e.g., a cheaper/faster model)
|
||||
instead of the pipeline's primary model.
|
||||
(PR [#3855](https://github.com/pipecat-ai/pipecat/pull/3855))
|
||||
|
||||
- Add AssemblyAI u3-rt-pro model support with built-in turn detection mode
|
||||
(PR [#3856](https://github.com/pipecat-ai/pipecat/pull/3856))
|
||||
|
||||
- Added `LLMSummarizeContextFrame` to trigger on-demand context summarization
|
||||
from anywhere in the pipeline (e.g. a function call tool). Accepts an
|
||||
optional `config: LLMContextSummaryConfig` to override summary generation
|
||||
settings per request.
|
||||
(PR [#3863](https://github.com/pipecat-ai/pipecat/pull/3863))
|
||||
|
||||
- Added `LLMContextSummaryConfig` (summary generation params:
|
||||
`target_context_tokens`, `min_messages_after_summary`,
|
||||
`summarization_prompt`) and `LLMAutoContextSummarizationConfig` (auto-trigger
|
||||
thresholds: `max_context_tokens`, `max_unsummarized_messages`, plus a nested
|
||||
`summary_config`). These replace the monolithic
|
||||
`LLMContextSummarizationConfig`.
|
||||
(PR [#3863](https://github.com/pipecat-ai/pipecat/pull/3863))
|
||||
|
||||
- Added support for the `speed_alpha` parameter to the `arcana` model in
|
||||
`RimeTTSService`.
|
||||
(PR [#3873](https://github.com/pipecat-ai/pipecat/pull/3873))
|
||||
|
||||
- Added `ClientConnectedFrame`, a new `SystemFrame` pushed by all transports
|
||||
(Daily, LiveKit, FastAPI WebSocket, WebSocket Server, SmallWebRTC, HeyGen,
|
||||
Tavus) when a client connects. Enables observers to track transport readiness
|
||||
timing.
|
||||
(PR [#3881](https://github.com/pipecat-ai/pipecat/pull/3881))
|
||||
|
||||
- Added `StartupTimingObserver` for measuring how long each processor's
|
||||
`start()` method takes during pipeline startup. Also measures transport
|
||||
readiness — the time from `StartFrame` to first client connection — via the
|
||||
`on_transport_timing_report` event.
|
||||
(PR [#3881](https://github.com/pipecat-ai/pipecat/pull/3881))
|
||||
|
||||
- Added `BotConnectedFrame` for SFU transports and `on_transport_timing_report`
|
||||
event to `StartupTimingObserver` with bot and client connection timing.
|
||||
(PR [#3881](https://github.com/pipecat-ai/pipecat/pull/3881))
|
||||
|
||||
- Added optional `direction` parameter to `PipelineTask.queue_frame()` and
|
||||
`PipelineTask.queue_frames()`, allowing frames to be pushed upstream from the
|
||||
end of the pipeline.
|
||||
(PR [#3883](https://github.com/pipecat-ai/pipecat/pull/3883))
|
||||
|
||||
- Added `on_latency_breakdown` event to `UserBotLatencyObserver` providing
|
||||
per-service TTFB, text aggregation, user turn duration, and function call
|
||||
latency metrics for each user-to-bot response cycle.
|
||||
(PR [#3885](https://github.com/pipecat-ai/pipecat/pull/3885))
|
||||
|
||||
- Added `on_first_bot_speech_latency` event to `UserBotLatencyObserver`
|
||||
measuring the time from client connection to first bot speech. An
|
||||
`on_latency_breakdown` is also emitted for this first speech event.
|
||||
(PR [#3885](https://github.com/pipecat-ai/pipecat/pull/3885))
|
||||
|
||||
- Added `broadcast_interruption()` to `FrameProcessor`. This method pushes an
|
||||
`InterruptionFrame` both upstream and downstream directly from the calling
|
||||
processor, avoiding the round-trip through the pipeline task that
|
||||
`push_interruption_task_frame_and_wait()` required.
|
||||
(PR [#3896](https://github.com/pipecat-ai/pipecat/pull/3896))
|
||||
|
||||
### Changed
|
||||
|
||||
- Added `text_aggregation_mode` parameter to `TTSService` and all TTS
|
||||
subclasses with a new `TextAggregationMode` enum (`SENTENCE`, `TOKEN`). All
|
||||
text now flows through text aggregators regardless of mode, enabling pattern
|
||||
detection and tag handling in TOKEN mode.
|
||||
(PR [#3696](https://github.com/pipecat-ai/pipecat/pull/3696))
|
||||
|
||||
- ⚠️ Refactored runtime-updatable service settings to use strongly-typed
|
||||
classes (`TTSSettings`, `STTSettings`, `LLMSettings`, and service-specific
|
||||
subclasses) instead of plain dicts. Each service's `_settings` now holds
|
||||
these strongly-typed objects. For service maintainers, see changes in
|
||||
COMMUNITY_INTEGRATIONS.md.
|
||||
(PR [#3714](https://github.com/pipecat-ai/pipecat/pull/3714))
|
||||
|
||||
- Word timestamp support has been moved from `WordTTSService` into `TTSService`
|
||||
via a new `supports_word_timestamps` parameter. Services that previously
|
||||
extended `WordTTSService`, `AudioContextWordTTSService`, or
|
||||
`WebsocketWordTTSService` now pass `supports_word_timestamps=True` to their
|
||||
parent `__init__` instead.
|
||||
(PR [#3786](https://github.com/pipecat-ai/pipecat/pull/3786))
|
||||
|
||||
- Improved Ultravox TTFB measurement accuracy by using VAD speech end time
|
||||
instead of `UserStoppedSpeakingFrame` timing.
|
||||
(PR [#3806](https://github.com/pipecat-ai/pipecat/pull/3806))
|
||||
|
||||
- Aligned `UltravoxRealtimeLLMService` frame handling with OpenAI/Gemini
|
||||
realtime services: added `InterruptionFrame` handling with metrics cleanup,
|
||||
processing metrics at response boundaries, and improved agent transcript
|
||||
handling for both voice and text output modalities.
|
||||
(PR [#3806](https://github.com/pipecat-ai/pipecat/pull/3806))
|
||||
|
||||
- Updated `OpenAIRealtimeLLMService` default model to `gpt-realtime-1.5`.
|
||||
(PR [#3807](https://github.com/pipecat-ai/pipecat/pull/3807))
|
||||
|
||||
- Added `api_key` parameter to `KrispVivaSDKManager`, `KrispVivaTurn`, and
|
||||
`KrispVivaFilter` for Krisp SDK v1.6.1+ licensing. Falls back to
|
||||
`KRISP_VIVA_API_KEY` environment variable.
|
||||
(PR [#3809](https://github.com/pipecat-ai/pipecat/pull/3809))
|
||||
|
||||
- Bumped `nltk` minimum version from 3.9.1 to 3.9.3 to resolve a security
|
||||
vulnerability.
|
||||
(PR [#3811](https://github.com/pipecat-ai/pipecat/pull/3811))
|
||||
|
||||
- `ServiceSettingsUpdateFrame`s are now `UninterruptibleFrame`s. Generally
|
||||
speaking, you don't want a user interruption to prevent a service setting
|
||||
change from going into effect. Note that you usually don't use
|
||||
`ServiceSettingsUpdateFrame` directly, you use one of its subclasses:
|
||||
- `LLMUpdateSettingsFrame`
|
||||
- `TTSUpdateSettingsFrame`
|
||||
- `STTUpdateSettingsFrame`
|
||||
(PR [#3819](https://github.com/pipecat-ai/pipecat/pull/3819))
|
||||
|
||||
- Updated context summarization to use `user` role instead of `assistant` for
|
||||
summary messages.
|
||||
(PR [#3855](https://github.com/pipecat-ai/pipecat/pull/3855))
|
||||
|
||||
- Rename `AssemblyAISTTService` parameter
|
||||
`min_end_of_turn_silence_when_confident` parameter to `min_turn_silence` (old
|
||||
name still supported with deprecation warning)
|
||||
(PR [#3856](https://github.com/pipecat-ai/pipecat/pull/3856))
|
||||
|
||||
- ⚠️ Renamed `LLMAssistantAggregatorParams` fields:
|
||||
`enable_context_summarization` → `enable_auto_context_summarization` and
|
||||
`context_summarization_config` → `auto_context_summarization_config` (now
|
||||
accepts `LLMAutoContextSummarizationConfig`). The old names still work with a
|
||||
`DeprecationWarning` for one release cycle.
|
||||
(PR [#3863](https://github.com/pipecat-ai/pipecat/pull/3863))
|
||||
|
||||
- `ElevenLabsRealtimeSTTService` now sets `TranscriptionFrame.finalized` to
|
||||
`True` when using `CommitStrategy.MANUAL`.
|
||||
(PR [#3865](https://github.com/pipecat-ai/pipecat/pull/3865))
|
||||
|
||||
- Updated numba version pin from == to >=0.61.2
|
||||
(PR [#3868](https://github.com/pipecat-ai/pipecat/pull/3868))
|
||||
|
||||
- Updated tracing code to use `ServiceSettings` dataclass API
|
||||
(`given_fields()`, attribute access) instead of dict-style access
|
||||
(`.items()`, `in`, subscript).
|
||||
(PR [#3879](https://github.com/pipecat-ai/pipecat/pull/3879))
|
||||
|
||||
- ⚠️ Removed `event` field and `complete()` method from `InterruptionFrame`.
|
||||
Removed `event` field from `InterruptionTaskFrame`. These are no longer
|
||||
needed since `broadcast_interruption()` does not require a round-trip
|
||||
completion signal.
|
||||
(PR [#3896](https://github.com/pipecat-ai/pipecat/pull/3896))
|
||||
|
||||
- Moved `pipecat.services.deepgram.stt_sagemaker` and
|
||||
`pipecat.services.deepgram.tts_sagemaker` to
|
||||
`pipecat.services.deepgram.sagemaker.stt` and
|
||||
`pipecat.services.deepgram.sagemaker.tts`. The old import paths still work
|
||||
but emit a `DeprecationWarning`.
|
||||
(PR [#3902](https://github.com/pipecat-ai/pipecat/pull/3902))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- ⚠️ Deprecated `aggregate_sentences` parameter on `TTSService` and all TTS
|
||||
subclasses. Use `text_aggregation_mode=TextAggregationMode.SENTENCE` or
|
||||
`text_aggregation_mode=TextAggregationMode.TOKEN` instead.
|
||||
(PR [#3696](https://github.com/pipecat-ai/pipecat/pull/3696))
|
||||
|
||||
- Deprecated `set_model()`, `set_voice()`, and `set_language()` on AI services
|
||||
in favor of runtime updates via `TTSUpdateSettingsFrame`,
|
||||
`STTUpdateSettingsFrame`, and `LLMUpdateSettingsFrame`.
|
||||
|
||||
⚠️ Note, too, a subtle behavior change in these deprecated methods. Whereas
|
||||
previously only `set_language()` caused the service to actually react to the
|
||||
update (e.g. by reconnecting to a remote service so it an pick up the
|
||||
change), now all these methods do. This change was made as part of a refactor
|
||||
making them all work the same way under the hood.
|
||||
(PR [#3714](https://github.com/pipecat-ai/pipecat/pull/3714))
|
||||
|
||||
- Dict-based `*UpdateSettingsFrame(settings={...})` is deprecated in favor of
|
||||
passing typed settings delta objects with
|
||||
`*UpdateSettingsFrame(delta={...})`.
|
||||
(PR [#3714](https://github.com/pipecat-ai/pipecat/pull/3714))
|
||||
|
||||
- Deprecated `WordTTSService`, `WebsocketWordTTSService`,
|
||||
`AudioContextWordTTSService`, and `InterruptibleWordTTSService`. Use their
|
||||
non-word counterparts with `supports_word_timestamps=True` instead:
|
||||
- `WordTTSService` → `TTSService(supports_word_timestamps=True)`
|
||||
- `WebsocketWordTTSService` →
|
||||
`WebsocketTTSService(supports_word_timestamps=True)`
|
||||
- `AudioContextWordTTSService` →
|
||||
`AudioContextTTSService(supports_word_timestamps=True)`
|
||||
- `InterruptibleWordTTSService` →
|
||||
`InterruptibleTTSService(supports_word_timestamps=True)`
|
||||
(PR [#3786](https://github.com/pipecat-ai/pipecat/pull/3786))
|
||||
|
||||
- Deprecated `SmartTurnMetricsData` in favor of `TurnMetricsData`.
|
||||
`BaseSmartTurn` now emits `TurnMetricsData` directly.
|
||||
(PR [#3809](https://github.com/pipecat-ai/pipecat/pull/3809))
|
||||
|
||||
- Deprecated `LLMContextSummarizationConfig`. Use
|
||||
`LLMAutoContextSummarizationConfig` with a nested `LLMContextSummaryConfig`
|
||||
instead. The old class emits a `DeprecationWarning`.
|
||||
(PR [#3863](https://github.com/pipecat-ai/pipecat/pull/3863))
|
||||
|
||||
- Deprecated `push_interruption_task_frame_and_wait()` in `FrameProcessor`. Use
|
||||
`broadcast_interruption()` instead. The old method now delegates to
|
||||
`broadcast_interruption()` and logs a deprecation warning.
|
||||
(PR [#3896](https://github.com/pipecat-ai/pipecat/pull/3896))
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed `local-smart-turn-v3` optional extra from `pyproject.toml`. The
|
||||
`transformers` and `onnxruntime` packages are now always installed as core
|
||||
dependencies since they are required by the default turn stop strategy,
|
||||
`TurnAnalyzerUserTurnStopStrategy` which uses `LocalSmartTurnAnalyzerV3`.
|
||||
(PR [#3803](https://github.com/pipecat-ai/pipecat/pull/3803))
|
||||
|
||||
- ⚠️ Removed `PlayHTTTSService` and `PlayHTHttpTTSService`. PlayHT has been
|
||||
shut down and is no longer available.
|
||||
(PR [#3838](https://github.com/pipecat-ai/pipecat/pull/3838))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Added `LLMSpecificMessage` handling in `LLMContextSummarizationUtil` to skip
|
||||
provider-specific messages during context summarization.
|
||||
(PR [#3794](https://github.com/pipecat-ai/pipecat/pull/3794))
|
||||
|
||||
- Treated `response_cancel_not_active` as a non-fatal error in realtime
|
||||
services (`OpenAIRealtimeLLMService`, `GrokRealtimeLLMService`,
|
||||
`OpenAIRealtimeBetaLLMService`) to prevent WebSocket disconnection when
|
||||
cancelling an inactive response.
|
||||
(PR [#3795](https://github.com/pipecat-ai/pipecat/pull/3795))
|
||||
|
||||
- Fixed Poetry compatibility by inlining `local-smart-turn-v3` dependencies
|
||||
(`transformers`, `onnxruntime`) into core dependencies instead of using a
|
||||
self-referential extra.
|
||||
(PR [#3803](https://github.com/pipecat-ai/pipecat/pull/3803))
|
||||
|
||||
- Fixed `SentryMetrics` method signatures to match updated
|
||||
`FrameProcessorMetrics` base class, resolving `TypeError` when using
|
||||
`start_time`/`end_time` keyword arguments.
|
||||
(PR [#3808](https://github.com/pipecat-ai/pipecat/pull/3808))
|
||||
|
||||
- Fixed STT TTFB metrics not being reported for `SonioxSTTService` and
|
||||
`AWSTranscribeSTTService` due to missing `can_generate_metrics()` override.
|
||||
(PR [#3813](https://github.com/pipecat-ai/pipecat/pull/3813))
|
||||
|
||||
- Fixed an issue where `AudioContextTTSService`-based providers (AsyncAI,
|
||||
ElevenLabs, Inworld, Rime) did not close or clean up their server-side audio
|
||||
contexts after normal speech completion, only on interruption.
|
||||
(PR [#3814](https://github.com/pipecat-ai/pipecat/pull/3814))
|
||||
|
||||
- Fixed STT TTFB metrics measuring timeout expiry time instead of actual
|
||||
transcript arrival time.
|
||||
(PR [#3822](https://github.com/pipecat-ai/pipecat/pull/3822))
|
||||
|
||||
- Fixed `InterimTranscriptionFrame` and `TranslationFrame` being
|
||||
unintentionally pushed downstream in `LLMUserAggregator`. They are now
|
||||
consumed like `TranscriptionFrame`.
|
||||
(PR [#3825](https://github.com/pipecat-ai/pipecat/pull/3825))
|
||||
|
||||
- Fixed misleading "Empty audio frame received for STT service" warnings when
|
||||
using audio filters (e.g. `RNNoiseFilter`, `KrispVivaFilter`, `AICFilter`)
|
||||
that buffer audio internally.
|
||||
(PR [#3828](https://github.com/pipecat-ai/pipecat/pull/3828))
|
||||
|
||||
- Fixed issues with `RimeNonJsonTTSService` where trailing punctuation is
|
||||
sometimes vocalized
|
||||
(PR [#3837](https://github.com/pipecat-ai/pipecat/pull/3837))
|
||||
|
||||
- Fixed `TTSSpeakFrame` not committing spoken text to the conversation context
|
||||
when used outside of an LLM response (e.g., bot greetings or injected
|
||||
speech).
|
||||
(PR [#3845](https://github.com/pipecat-ai/pipecat/pull/3845))
|
||||
|
||||
- Removed verbose per-chunk audio logging from `GenesysAudioHookSerializer`
|
||||
that flooded production logs.
|
||||
(PR [#3850](https://github.com/pipecat-ai/pipecat/pull/3850))
|
||||
|
||||
- Add beta feature warning when using custom prompts with AssemblyAI
|
||||
(PR [#3856](https://github.com/pipecat-ai/pipecat/pull/3856))
|
||||
|
||||
- Fixed `LocalSmartTurnAnalyzerV3` producing incorrect end-of-turn predictions
|
||||
at non-16kHz sample rates (e.g. 8kHz Twilio telephony) by adding automatic
|
||||
resampling to 16kHz before Whisper feature extraction.
|
||||
(PR [#3857](https://github.com/pipecat-ai/pipecat/pull/3857))
|
||||
|
||||
- Fixed `PipelineTask` double-inserting `RTVIProcessor` into the frame chain
|
||||
when the user provides both an `RTVIProcessor` in the pipeline and a custom
|
||||
`RTVIObserver` subclass in observers.
|
||||
(PR [#3867](https://github.com/pipecat-ai/pipecat/pull/3867))
|
||||
|
||||
- Fixed turn completion instructions being lost when `LLMMessagesUpdateFrame`
|
||||
replaces the LLM context. When `filter_incomplete_user_turns` is enabled, the
|
||||
turn completion system message is now re-injected after context replacement.
|
||||
(PR [#3888](https://github.com/pipecat-ai/pipecat/pull/3888))
|
||||
|
||||
- Fixed Azure TTS and STT services silently swallowing cancellation errors
|
||||
(invalid API key, network failures, rate limiting) instead of propagating
|
||||
them as `ErrorFrame`s to the pipeline.
|
||||
(PR [#3893](https://github.com/pipecat-ai/pipecat/pull/3893))
|
||||
|
||||
### Performance
|
||||
|
||||
- Switched `GradiumTTSService` from `InterruptibleWordTTSService` to
|
||||
`AudioContextWordTTSService`, eliminating websocket disconnect/reconnect on
|
||||
every interruption by using `client_req_id`-based multiplexing.
|
||||
(PR [#3759](https://github.com/pipecat-ai/pipecat/pull/3759))
|
||||
|
||||
### Other
|
||||
|
||||
- Standardized Sarvam STT/TTS User-Agent header handling to consistently send
|
||||
Pipecat SDK identity in websocket requests.
|
||||
(PR [#3886](https://github.com/pipecat-ai/pipecat/pull/3886))
|
||||
|
||||
## [0.0.103] - 2026-02-20
|
||||
|
||||
### Added
|
||||
|
||||
@@ -89,7 +89,7 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [Grok Voice Agent](https://docs.pipecat.ai/server/services/s2s/grok), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai), [Ultravox](https://docs.pipecat.ai/server/services/s2s/ultravox), |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Exotel](https://docs.pipecat.ai/server/utilities/serializers/exotel), [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx), [Vonage](https://docs.pipecat.ai/server/utilities/serializers/vonage) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [LemonSlice](https://docs.pipecat.ai/server/services/video/lemonslice), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/google-imagen), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
- Added `TextAggregationMetricsData` metric measuring the time from the first LLM token to the first complete sentence, representing the latency cost of sentence aggregation in the TTS pipeline.
|
||||
@@ -1 +0,0 @@
|
||||
- Added `text_aggregation_mode` parameter to `TTSService` and all TTS subclasses with a new `TextAggregationMode` enum (`SENTENCE`, `TOKEN`). All text now flows through text aggregators regardless of mode, enabling pattern detection and tag handling in TOKEN mode.
|
||||
@@ -1 +0,0 @@
|
||||
- ⚠️ Deprecated `aggregate_sentences` parameter on `TTSService` and all TTS subclasses. Use `text_aggregation_mode=TextAggregationMode.SENTENCE` or `text_aggregation_mode=TextAggregationMode.TOKEN` instead.
|
||||
@@ -1,19 +0,0 @@
|
||||
- Added support for using strongly-typed objects instead of dicts for updating service settings at runtime.
|
||||
|
||||
Instead of, say:
|
||||
|
||||
```python
|
||||
await task.queue_frame(
|
||||
STTUpdateSettingsFrame(settings={"language": Language.ES})
|
||||
)
|
||||
```
|
||||
|
||||
you'd do:
|
||||
|
||||
```python
|
||||
await task.queue_frame(
|
||||
STTUpdateSettingsFrame(delta=DeepgramSTTSettings(language=Language.ES))
|
||||
)
|
||||
```
|
||||
|
||||
Each service now vends strongly-typed classes like `DeepgramSTTSettings` representing the service's runtime-updatable settings.
|
||||
@@ -1 +0,0 @@
|
||||
- ⚠️ Refactored runtime-updatable service settings to use strongly-typed classes (`TTSSettings`, `STTSettings`, `LLMSettings`, and service-specific subclasses) instead of plain dicts. Each service's `_settings` now holds these strongly-typed objects. For service maintainers, see changes in COMMUNITY_INTEGRATIONS.md.
|
||||
@@ -1 +0,0 @@
|
||||
- Dict-based `*UpdateSettingsFrame(settings={...})` is deprecated in favor of passing typed settings delta objects with `*UpdateSettingsFrame(delta={...})`.
|
||||
@@ -1,3 +0,0 @@
|
||||
- Deprecated `set_model()`, `set_voice()`, and `set_language()` on AI services in favor of runtime updates via `TTSUpdateSettingsFrame`, `STTUpdateSettingsFrame`, and `LLMUpdateSettingsFrame`.
|
||||
|
||||
⚠️ Note, too, a subtle behavior change in these deprecated methods. Whereas previously only `set_language()` caused the service to actually react to the update (e.g. by reconnecting to a remote service so it an pick up the change), now all these methods do. This change was made as part of a refactor making them all work the same way under the hood.
|
||||
@@ -1 +0,0 @@
|
||||
- Switched `GradiumTTSService` from `InterruptibleWordTTSService` to `AudioContextWordTTSService`, eliminating websocket disconnect/reconnect on every interruption by using `client_req_id`-based multiplexing.
|
||||
@@ -1 +0,0 @@
|
||||
- Word timestamp support has been moved from `WordTTSService` into `TTSService` via a new `supports_word_timestamps` parameter. Services that previously extended `WordTTSService`, `AudioContextWordTTSService`, or `WebsocketWordTTSService` now pass `supports_word_timestamps=True` to their parent `__init__` instead.
|
||||
@@ -1,5 +0,0 @@
|
||||
- Deprecated `WordTTSService`, `WebsocketWordTTSService`, `AudioContextWordTTSService`, and `InterruptibleWordTTSService`. Use their non-word counterparts with `supports_word_timestamps=True` instead:
|
||||
- `WordTTSService` → `TTSService(supports_word_timestamps=True)`
|
||||
- `WebsocketWordTTSService` → `WebsocketTTSService(supports_word_timestamps=True)`
|
||||
- `AudioContextWordTTSService` → `AudioContextTTSService(supports_word_timestamps=True)`
|
||||
- `InterruptibleWordTTSService` → `InterruptibleTTSService(supports_word_timestamps=True)`
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed Poetry compatibility by inlining `local-smart-turn-v3` dependencies (`transformers`, `onnxruntime`) into core dependencies instead of using a self-referential extra.
|
||||
@@ -1 +0,0 @@
|
||||
- Removed `local-smart-turn-v3` optional extra from `pyproject.toml`. The `transformers` and `onnxruntime` packages are now always installed as core dependencies since they are required by the default turn stop strategy, `TurnAnalyzerUserTurnStopStrategy` which uses `LocalSmartTurnAnalyzerV3`.
|
||||
@@ -1 +0,0 @@
|
||||
- Added `output_medium` parameter to `AgentInputParams` and `OneShotInputParams` in Ultravox service to control initial output medium (text or voice) at call creation time.
|
||||
@@ -1 +0,0 @@
|
||||
- Improved Ultravox TTFB measurement accuracy by using VAD speech end time instead of `UserStoppedSpeakingFrame` timing.
|
||||
@@ -1 +0,0 @@
|
||||
- Aligned `UltravoxRealtimeLLMService` frame handling with OpenAI/Gemini realtime services: added `InterruptionFrame` handling with metrics cleanup, processing metrics at response boundaries, and improved agent transcript handling for both voice and text output modalities.
|
||||
@@ -1 +0,0 @@
|
||||
- Updated `OpenAIRealtimeLLMService` default model to `gpt-realtime-1.5`.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed `SentryMetrics` method signatures to match updated `FrameProcessorMetrics` base class, resolving `TypeError` when using `start_time`/`end_time` keyword arguments.
|
||||
@@ -1 +0,0 @@
|
||||
- Added `TurnMetricsData` as a generic metrics class for turn detection, with e2e processing time measurement. `KrispVivaTurn` now emits `TurnMetricsData` with `e2e_processing_time_ms` tracking the interval from VAD speech-to-silence transition to turn completion.
|
||||
@@ -1 +0,0 @@
|
||||
- Added `api_key` parameter to `KrispVivaSDKManager`, `KrispVivaTurn`, and `KrispVivaFilter` for Krisp SDK v1.6.1+ licensing. Falls back to `KRISP_VIVA_API_KEY` environment variable.
|
||||
@@ -1 +0,0 @@
|
||||
- Deprecated `SmartTurnMetricsData` in favor of `TurnMetricsData`. `BaseSmartTurn` now emits `TurnMetricsData` directly.
|
||||
@@ -1 +0,0 @@
|
||||
- Bumped `nltk` minimum version from 3.9.1 to 3.9.3 to resolve a security vulnerability.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed STT TTFB metrics not being reported for `SonioxSTTService` and `AWSTranscribeSTTService` due to missing `can_generate_metrics()` override.
|
||||
@@ -1 +0,0 @@
|
||||
- Added `on_audio_context_interrupted()` and `on_audio_context_completed()` callbacks to `AudioContextTTSService`. Subclasses can override these to perform provider-specific cleanup instead of overriding `_handle_interruption()`.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed an issue where `AudioContextTTSService`-based providers (AsyncAI, ElevenLabs, Inworld, Rime) did not close or clean up their server-side audio contexts after normal speech completion, only on interruption.
|
||||
@@ -1,4 +0,0 @@
|
||||
- `ServiceSettingsUpdateFrame`s are now `UninterruptibleFrame`s. Generally speaking, you don't want a user interruption to prevent a service setting change from going into effect. Note that you usually don't use `ServiceSettingsUpdateFrame` directly, you use one of its subclasses:
|
||||
- `LLMUpdateSettingsFrame`
|
||||
- `TTSUpdateSettingsFrame`
|
||||
- `STTUpdateSettingsFrame`
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed STT TTFB metrics measuring timeout expiry time instead of actual transcript arrival time.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed `InterimTranscriptionFrame` and `TranslationFrame` being unintentionally pushed downstream in `LLMUserAggregator`. They are now consumed like `TranscriptionFrame`.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed misleading "Empty audio frame received for STT service" warnings when using audio filters (e.g. `RNNoiseFilter`, `KrispVivaFilter`, `AICFilter`) that buffer audio internally.
|
||||
@@ -1 +0,0 @@
|
||||
- Fixed issues with `RimeNonJsonTTSService` where trailing punctuation is sometimes vocalized
|
||||
@@ -1 +0,0 @@
|
||||
- ⚠️ Removed `PlayHTTTSService` and `PlayHTHttpTTSService`. PlayHT has been shut down and is no longer available.
|
||||
@@ -108,6 +108,10 @@ KRISP_VIVA_API_KEY=...
|
||||
KRISP_VIVA_FILTER_MODEL_PATH=...
|
||||
KRISP_VIVA_TURN_MODEL_PATH=...
|
||||
|
||||
# LemonSlice
|
||||
LEMONSLICE_API_KEY=...
|
||||
LEMONSLICE_AGENT_ID=...
|
||||
|
||||
# LiveKit
|
||||
LIVEKIT_API_KEY=...
|
||||
LIVEKIT_API_SECRET=...
|
||||
|
||||
@@ -10,6 +10,7 @@ import os
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -72,7 +73,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(user_turn_strategies=ExternalUserTurnStrategies()),
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=ExternalUserTurnStrategies(),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -23,8 +23,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.tts_sagemaker import DeepgramSageMakerTTSService
|
||||
from pipecat.services.deepgram.sagemaker.stt import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.sagemaker.tts import DeepgramSageMakerTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.assemblyai.models import AssemblyAIConnectionParams
|
||||
from pipecat.services.assemblyai.stt import AssemblyAISTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
"""AssemblyAI u3-rt-pro with Built-in Turn Detection
|
||||
|
||||
This example demonstrates using AssemblyAI's u3-rt-pro Speech-to-Text model
|
||||
with AssemblyAI's built-in turn detection for more natural conversation flow.
|
||||
|
||||
Key features:
|
||||
|
||||
1. AssemblyAI Turn Detection
|
||||
- Set `vad_force_turn_endpoint=False` to use AssemblyAI's built-in turn detection
|
||||
- AssemblyAI's model determines when user starts/stops speaking
|
||||
- Uses `ExternalUserTurnStrategies` to delegate turn control to AssemblyAI
|
||||
- More natural turn detection based on speech patterns and pauses
|
||||
|
||||
2. Advanced Turn Detection Tuning
|
||||
- `min_turn_silence`: Minimum silence (ms) when confident about end-of-turn.
|
||||
Lower values = faster responses. Default: 100ms
|
||||
- `max_turn_silence`: Maximum silence (ms) before forcing end-of-turn.
|
||||
Prevents long pauses. Default: 1000ms
|
||||
|
||||
3. Prompt-Based Transcription Enhancement
|
||||
- Use `prompt` parameter to improve accuracy for specific names/terms
|
||||
- Particularly useful for proper nouns, technical terms, domain vocabulary
|
||||
- Example: "Names: Xiomara, Saoirse, Krzystof. Technical terms: API, OAuth."
|
||||
|
||||
4. Speaker Diarization (Optional)
|
||||
- Enable with `speaker_labels=True`
|
||||
- Automatically identifies different speakers in multi-party conversations
|
||||
- TranscriptionFrame includes speaker_id field (e.g., "Speaker A", "Speaker B")
|
||||
|
||||
5. Language Detection (Optional, multilingual model only)
|
||||
- Enable with `language_detection=True`
|
||||
- Automatically detects spoken language
|
||||
- Available with universal-streaming-multilingual model
|
||||
|
||||
For more information: https://www.assemblyai.com/docs/speech-to-text/streaming
|
||||
"""
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = AssemblyAISTTService(
|
||||
api_key=os.getenv("ASSEMBLYAI_API_KEY"),
|
||||
vad_force_turn_endpoint=False, # Use AssemblyAI's built-in turn detection
|
||||
connection_params=AssemblyAIConnectionParams(
|
||||
speech_model="u3-rt-pro",
|
||||
# Optional: Tune turn detection timing (defaults shown below)
|
||||
# min_turn_silence=100, # Default
|
||||
# max_turn_silence=1000, # Default
|
||||
# Optional: Boost accuracy for specific names/terms
|
||||
# prompt="Names: Xiomara, Saoirse, Krzystof. Technical terms: API, OAuth.",
|
||||
# Optional: Enable speaker diarization
|
||||
# speaker_labels=True,
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=ExternalUserTurnStrategies(),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -55,7 +55,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="meta/llama-3.3-70b-instruct",
|
||||
)
|
||||
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
@@ -16,6 +16,7 @@ from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.assemblyai.models import AssemblyAIConnectionParams
|
||||
from pipecat.services.assemblyai.stt import AssemblyAISTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -49,6 +50,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
stt = AssemblyAISTTService(
|
||||
api_key=os.getenv("ASSEMBLYAI_API_KEY"),
|
||||
connection_params=AssemblyAIConnectionParams(
|
||||
speech_model="u3-rt-pro",
|
||||
),
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
@@ -12,12 +12,15 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
@@ -42,20 +45,14 @@ transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -104,17 +101,20 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -5,13 +5,17 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.observers.startup_timing_observer import StartupTimingObserver
|
||||
from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -25,6 +29,7 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
@@ -32,6 +37,17 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await asyncio.sleep(0.25)
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await asyncio.sleep(0.1)
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
@@ -62,6 +78,38 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -69,7 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context = LLMContext(messages, tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
@@ -87,8 +135,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
)
|
||||
|
||||
# Create latency tracking observer
|
||||
latency_observer = UserBotLatencyObserver()
|
||||
startup_observer = StartupTimingObserver()
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
@@ -97,14 +145,29 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[latency_observer],
|
||||
observers=[latency_observer, startup_observer],
|
||||
)
|
||||
|
||||
# Log latency measurements using the event handler
|
||||
@latency_observer.event_handler("on_first_bot_speech_latency")
|
||||
async def on_first_bot_speech_latency(observer, latency_seconds):
|
||||
logger.info(f"First bot speech: {latency_seconds:.3f}s after client connected")
|
||||
|
||||
@latency_observer.event_handler("on_latency_measured")
|
||||
async def on_latency_measured(observer, latency_seconds):
|
||||
logger.info(f"⏱️ User-to-bot latency: {latency_seconds:.3f}s")
|
||||
|
||||
@startup_observer.event_handler("on_startup_timing_report")
|
||||
async def on_startup_timing_report(observer, report):
|
||||
logger.info(f"Total startup: {report.total_duration_secs:.3f}s")
|
||||
for timing in report.processor_timings:
|
||||
logger.info(f" {timing.processor_name}: {timing.duration_secs:.3f}s")
|
||||
|
||||
@startup_observer.event_handler("on_transport_timing_report")
|
||||
async def on_transport_timing_report(observer, report):
|
||||
if report.bot_connected_secs is not None:
|
||||
logger.info(f"Bot connected: {report.bot_connected_secs:.3f}s")
|
||||
logger.info(f"Client connected: {report.client_connected_secs:.3f}s")
|
||||
|
||||
turn_observer = task.turn_tracking_observer
|
||||
if turn_observer:
|
||||
|
||||
@@ -119,6 +182,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
else:
|
||||
logger.info(f"🏁 Turn {turn_number} completed in {duration:.2f}s")
|
||||
|
||||
@latency_observer.event_handler("on_latency_breakdown")
|
||||
async def on_latency_breakdown(observer, breakdown):
|
||||
for event in breakdown.chronological_events():
|
||||
logger.info(f" {event}")
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
@@ -11,6 +11,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -110,6 +111,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
await task.queue_frames(
|
||||
[
|
||||
TTSSpeakFrame(
|
||||
text="Hello, welcome to live translation. Everything you say will be automatically translated to Spanish. Let's begin!",
|
||||
append_to_context=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -20,14 +20,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
@@ -42,9 +41,10 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummaryConfig,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -120,24 +120,36 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_context_summarization=True,
|
||||
enable_auto_context_summarization=True,
|
||||
# Optional: customize context summarization behavior
|
||||
# Using low limits to demonstrate the feature quickly
|
||||
context_summarization_config=LLMContextSummarizationConfig(
|
||||
auto_context_summarization_config=LLMAutoContextSummarizationConfig(
|
||||
max_context_tokens=1000, # Trigger summarization at 1000 tokens
|
||||
target_context_tokens=800, # Target context size for the summarization
|
||||
max_unsummarized_messages=10, # Or when 10 new messages accumulate
|
||||
min_messages_after_summary=2, # Keep last 2 messages uncompressed
|
||||
summary_config=LLMContextSummaryConfig(
|
||||
target_context_tokens=800, # Target context size for the summarization
|
||||
min_messages_after_summary=2, # Keep last 2 messages uncompressed
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
|
||||
@@ -20,14 +20,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
@@ -42,9 +41,10 @@ from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummaryConfig,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -120,24 +120,36 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_context_summarization=True,
|
||||
enable_auto_context_summarization=True,
|
||||
# Optional: customize context summarization behavior
|
||||
# Using low limits to demonstrate the feature quickly
|
||||
context_summarization_config=LLMContextSummarizationConfig(
|
||||
auto_context_summarization_config=LLMAutoContextSummarizationConfig(
|
||||
max_context_tokens=1000, # Trigger summarization at 1000 tokens
|
||||
target_context_tokens=800, # Target context size for the summarization
|
||||
max_unsummarized_messages=10, # Or when 10 new messages accumulate
|
||||
min_messages_after_summary=2, # Keep last 2 messages uncompressed
|
||||
summary_config=LLMContextSummaryConfig(
|
||||
target_context_tokens=800, # Target context size for the summarization
|
||||
min_messages_after_summary=2, # Keep last 2 messages uncompressed
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
|
||||
172
examples/foundational/54b-context-summarization-manual-openai.py
Normal file
172
examples/foundational/54b-context-summarization-manual-openai.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example demonstrating manual context summarization via a function call.
|
||||
|
||||
This example shows how to trigger context summarization on demand rather than
|
||||
automatically. The user can ask the bot to "summarize the conversation" and the
|
||||
bot will call a function that pushes an LLMSummarizeContextFrame into the
|
||||
pipeline, causing the LLM service to compress the conversation history.
|
||||
|
||||
Unlike example 54, automatic summarization is NOT enabled here. Summarization
|
||||
only happens when the user explicitly requests it through the function call.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSummarizeContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def summarize_conversation(params: FunctionCallParams):
|
||||
"""Trigger manual context summarization via a pipeline frame."""
|
||||
logger.info("Tool called: summarize_conversation")
|
||||
await params.result_callback({"status": "summarization_requested"})
|
||||
await params.llm.queue_frame(LLMSummarizeContextFrame())
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
llm.register_function("summarize_conversation", summarize_conversation)
|
||||
|
||||
summarize_function = FunctionSchema(
|
||||
name="summarize_conversation",
|
||||
description=(
|
||||
"Summarize and compress the conversation history. "
|
||||
"Call this when the user asks you to summarize the conversation "
|
||||
"or when you want to free up context space."
|
||||
),
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[summarize_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your "
|
||||
"capabilities in a succinct way. Your output will be spoken aloud, so avoid "
|
||||
"special characters that can't easily be spoken, such as emojis or bullet points. "
|
||||
"Respond to what the user said in a creative and helpful way. "
|
||||
"If the user asks you to summarize the conversation, call the "
|
||||
"summarize_conversation function. After summarization, briefly acknowledge "
|
||||
"that the conversation history has been compressed."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools=tools)
|
||||
|
||||
# Automatic summarization is NOT enabled here (enable_auto_context_summarization
|
||||
# defaults to False). The summarizer is still created internally so that
|
||||
# LLMSummarizeContextFrame frames pushed via the function call are handled.
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
236
examples/foundational/54c-context-summarization-dedicated-llm.py
Normal file
236
examples/foundational/54c-context-summarization-dedicated-llm.py
Normal file
@@ -0,0 +1,236 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example demonstrating advanced context summarization configuration.
|
||||
|
||||
This example shows how to customize context summarization with:
|
||||
- A dedicated cheap/fast LLM for generating summaries (Gemini Flash)
|
||||
- A custom summary message template (XML tags)
|
||||
- A custom summarization prompt
|
||||
- A summarization timeout
|
||||
- The on_summary_applied event for observability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummaryConfig,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
# Custom summarization prompt tailored to the application
|
||||
CUSTOM_SUMMARIZATION_PROMPT = """Summarize this conversation, preserving:
|
||||
- Key decisions and agreements
|
||||
- Important facts and user preferences
|
||||
- Any pending action items or unresolved questions
|
||||
|
||||
Be concise. Use clear, factual statements grouped by topic.
|
||||
Omit greetings, small talk, and resolved tangents."""
|
||||
|
||||
|
||||
# Tool functions for the LLM
|
||||
async def get_current_weather(params: FunctionCallParams):
|
||||
"""Get the current weather."""
|
||||
logger.info("Tool called: get_current_weather")
|
||||
await asyncio.sleep(1) # Simulate some processing
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Primary LLM for conversation (could be any provider)
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Dedicated cheap/fast LLM for summarization only
|
||||
summarization_llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
)
|
||||
|
||||
# Register tool functions
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful LLM in a WebRTC call. Your goal is to demonstrate "
|
||||
"your capabilities in a succinct way. Your output will be spoken aloud, "
|
||||
"so avoid special characters that can't easily be spoken. Respond to what "
|
||||
"the user said in a creative and helpful way. You have access to tools to "
|
||||
"get the current weather - use them when relevant.\n\n"
|
||||
"When you see a <context_summary> block, it contains a compressed summary "
|
||||
"of earlier conversation. Use it as reference but don't mention it to the user."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools=tools)
|
||||
|
||||
# Create aggregators with custom summarization
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_auto_context_summarization=True,
|
||||
auto_context_summarization_config=LLMAutoContextSummarizationConfig(
|
||||
# Trigger thresholds (low values to demonstrate quickly)
|
||||
max_context_tokens=1000,
|
||||
max_unsummarized_messages=10,
|
||||
summary_config=LLMContextSummaryConfig(
|
||||
# Summary generation
|
||||
target_context_tokens=800,
|
||||
min_messages_after_summary=2,
|
||||
summarization_prompt=CUSTOM_SUMMARIZATION_PROMPT,
|
||||
# Custom summary format - wrap in XML tags so the system
|
||||
# prompt can identify summaries vs. live conversation
|
||||
summary_message_template="<context_summary>\n{summary}\n</context_summary>",
|
||||
# Use a dedicated cheap LLM for summarization instead of
|
||||
# the primary conversation model
|
||||
llm=summarization_llm,
|
||||
# Cancel summarization if it takes longer than 60 seconds
|
||||
summarization_timeout=60.0,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -24,7 +24,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt_sagemaker import (
|
||||
from pipecat.services.deepgram.sagemaker.stt import (
|
||||
DeepgramSageMakerSTTService,
|
||||
DeepgramSageMakerSTTSettings,
|
||||
)
|
||||
|
||||
@@ -22,10 +22,10 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.assemblyai.models import AssemblyAIConnectionParams
|
||||
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -51,7 +51,12 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = AssemblyAISTTService(api_key=os.getenv("ASSEMBLYAI_API_KEY"))
|
||||
stt = AssemblyAISTTService(
|
||||
api_key=os.getenv("ASSEMBLYAI_API_KEY"),
|
||||
connection_params=AssemblyAIConnectionParams(
|
||||
speech_model="u3-rt-pro",
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
@@ -63,7 +68,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
"content": "You are a helpful LLM in a WebRTC call demonstrating dynamic keyterms updates. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Try saying difficult names like 'Xiomara', 'Saoirse', or 'Krzystof' to test transcription accuracy.",
|
||||
},
|
||||
]
|
||||
|
||||
@@ -97,14 +102,24 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(
|
||||
"Phase 1: No keyterms boosting - try saying 'Xiomara', 'Saoirse', or 'Krzystof'"
|
||||
)
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
await asyncio.sleep(10)
|
||||
logger.info("Updating AssemblyAI STT settings: language=es")
|
||||
await asyncio.sleep(15)
|
||||
logger.info("🔄 Updating keyterms: Adding difficult names for boosting")
|
||||
await task.queue_frame(
|
||||
STTUpdateSettingsFrame(delta=AssemblyAISTTSettings(language=Language.ES))
|
||||
STTUpdateSettingsFrame(
|
||||
delta=AssemblyAISTTSettings(
|
||||
connection_params=AssemblyAIConnectionParams(
|
||||
keyterms_prompt=["Xiomara", "Saoirse", "Krzystof", "Nguyen", "Pipecat"]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
logger.info("Phase 2: Keyterms active - same names should transcribe better now!")
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -22,11 +22,11 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts_sagemaker import (
|
||||
from pipecat.services.deepgram.sagemaker.tts import (
|
||||
DeepgramSageMakerTTSService,
|
||||
DeepgramSageMakerTTSSettings,
|
||||
)
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
123
examples/foundational/56-lemonslice-transport.py
Normal file
123
examples/foundational/56-lemonslice-transport.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.transports.lemonslice.transport import (
|
||||
LemonSliceNewSessionRequest,
|
||||
LemonSliceParams,
|
||||
LemonSliceTransport,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
transport = LemonSliceTransport(
|
||||
bot_name="Pipecat",
|
||||
api_key=os.getenv("LEMONSLICE_API_KEY"),
|
||||
session=session,
|
||||
session_request=LemonSliceNewSessionRequest(
|
||||
agent_id=os.getenv("LEMONSLICE_AGENT_ID"),
|
||||
),
|
||||
params=LemonSliceParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
microphone_out_enabled=False,
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"))
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY", ""),
|
||||
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, participant):
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Start by greeting the user and ask how you can help.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, participant):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -121,6 +121,7 @@ uv run 07-interruptible.py -t twilio -x NGROK_HOST_NAME
|
||||
- **[19-openai-realtime-beta.py](./19-openai-realtime-beta.py)**: OpenAI Speech-to-Speech (Direct S2S, Function calls)
|
||||
- **[21-tavus-layer-tavus-transport.py](./21-tavus-layer-tavus-transport.py)**: Tavus digital twin (Avatar integration)
|
||||
- **[27-simli-layer.py](./27-simli-layer.py)**: Simli avatar integration (Video synchronization)
|
||||
- **[56-lemonslice-transport.py](./56-lemonslice-transport.py)**: LemonSlice avatar integration (A/V Synced Avatar integration)
|
||||
|
||||
### Performance & Optimization
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ dependencies = [
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<3",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
"numba>=0.61.2",
|
||||
"wait_for2>=0.4.1; python_version<'3.12'",
|
||||
# Required by LocalSmartTurnAnalyzerV3
|
||||
# Inlined here instead of using a self-referential extra for Poetry compatibility.
|
||||
@@ -82,6 +82,7 @@ koala = [ "pvkoala~=2.0.3" ]
|
||||
kokoro = [ "kokoro-onnx>=0.5.0,<1", "requests>=2.32.5,<3" ]
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
lemonslice = [ "pipecat-ai[daily]" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0", "pyjwt>=2.10.1" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
|
||||
@@ -14,12 +14,16 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soxr
|
||||
from loguru import logger
|
||||
from transformers import WhisperFeatureExtractor
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
from pipecat.utils.env import env_truthy
|
||||
|
||||
# The Whisper-based ONNX model expects 16 kHz audio input.
|
||||
_MODEL_SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""Local turn analyzer using the smart-turn-v3 ONNX model.
|
||||
@@ -77,7 +81,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
|
||||
def _write_audio_to_wav(
|
||||
self, audio_array: np.ndarray, sample_rate: int = 16000, suffix: str = ""
|
||||
self, audio_array: np.ndarray, sample_rate: int = _MODEL_SAMPLE_RATE, suffix: str = ""
|
||||
) -> None:
|
||||
"""Write audio data to a WAV file in a background thread.
|
||||
|
||||
@@ -119,10 +123,27 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
thread = threading.Thread(target=write_wav, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _resample_to_model_rate(self, audio_array: np.ndarray) -> np.ndarray:
|
||||
"""Resample audio to the model's expected sample rate (16 kHz).
|
||||
|
||||
Args:
|
||||
audio_array: Audio data as a float32 numpy array.
|
||||
|
||||
Returns:
|
||||
Resampled audio array at 16 kHz.
|
||||
"""
|
||||
actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE
|
||||
if actual_rate == _MODEL_SAMPLE_RATE:
|
||||
return audio_array
|
||||
|
||||
return soxr.resample(audio_array, actual_rate, _MODEL_SAMPLE_RATE, quality="VHQ")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000):
|
||||
def truncate_audio_to_last_n_seconds(
|
||||
audio_array, n_seconds=8, sample_rate=_MODEL_SAMPLE_RATE
|
||||
):
|
||||
"""Truncate audio to last n seconds or pad with zeros to meet n seconds."""
|
||||
max_samples = n_seconds * sample_rate
|
||||
if len(audio_array) > max_samples:
|
||||
@@ -134,6 +155,10 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
return audio_array
|
||||
|
||||
audio_for_logging = audio_array
|
||||
actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE
|
||||
|
||||
# Resample to 16 kHz if the pipeline uses a different sample rate
|
||||
audio_array = self._resample_to_model_rate(audio_array)
|
||||
|
||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||
@@ -141,10 +166,10 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
# Process audio using Whisper's feature extractor
|
||||
inputs = self._feature_extractor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
sampling_rate=_MODEL_SAMPLE_RATE,
|
||||
return_tensors="np",
|
||||
padding="max_length",
|
||||
max_length=8 * 16000,
|
||||
max_length=8 * _MODEL_SAMPLE_RATE,
|
||||
truncation=True,
|
||||
do_normalize=True,
|
||||
)
|
||||
@@ -164,7 +189,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
if self._log_data:
|
||||
suffix = "_complete" if prediction == 1 else "_incomplete"
|
||||
self._write_audio_to_wav(audio_for_logging, sample_rate=16000, suffix=suffix)
|
||||
self._write_audio_to_wav(audio_for_logging, sample_rate=actual_rate, suffix=suffix)
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
|
||||
@@ -368,7 +368,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
await self._voicemail_notifier.notify() # Clear buffered TTS frames
|
||||
|
||||
# Interrupt the current pipeline to stop any ongoing processing
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
# Set the voicemail event to trigger the voicemail handler
|
||||
self._voicemail_event.clear()
|
||||
|
||||
@@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text,
|
||||
and LLM processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
@@ -43,6 +42,7 @@ if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.services.settings import ServiceSettings
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig
|
||||
from pipecat.utils.tracing.tracing_context import TracingContext
|
||||
|
||||
|
||||
@@ -1140,24 +1140,9 @@ class InterruptionFrame(SystemFrame):
|
||||
This frame is used to interrupt the pipeline. For example, when a user
|
||||
starts speaking to cancel any in-progress bot output. It can also be pushed
|
||||
by any processor.
|
||||
|
||||
Parameters:
|
||||
event: Optional event set when the frame has fully traversed the
|
||||
pipeline.
|
||||
|
||||
"""
|
||||
|
||||
event: Optional[asyncio.Event] = None
|
||||
|
||||
def complete(self):
|
||||
"""Signal that this interruption has been fully processed.
|
||||
|
||||
Called automatically when the frame reaches the pipeline sink, or
|
||||
manually when the frame is consumed before reaching it (e.g. when
|
||||
the user is muted).
|
||||
"""
|
||||
if self.event:
|
||||
self.event.set()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1824,16 +1809,11 @@ class InterruptionTaskFrame(TaskFrame):
|
||||
"""Frame indicating the pipeline should be interrupted.
|
||||
|
||||
This frame should be pushed upstream to indicate the pipeline should be
|
||||
interrupted. The pipeline task converts this into an `InterruptionFrame` and
|
||||
sends it downstream. The `event` is passed to the `InterruptionFrame` so it
|
||||
can signal when the interruption has fully traversed the pipeline.
|
||||
|
||||
Parameters:
|
||||
event: Optional event passed to the corresponding `InterruptionFrame`.
|
||||
|
||||
interrupted. The pipeline task converts this into an `InterruptionFrame`
|
||||
and sends it downstream.
|
||||
"""
|
||||
|
||||
event: Optional[asyncio.Event] = None
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1909,6 +1889,29 @@ class StopFrame(ControlFrame, UninterruptibleFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotConnectedFrame(SystemFrame):
|
||||
"""Frame indicating the bot has connected to the transport service.
|
||||
|
||||
Pushed downstream by SFU transports (Daily, LiveKit, HeyGen, Tavus)
|
||||
when the bot successfully joins the room. Non-SFU transports do not
|
||||
emit this frame.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConnectedFrame(SystemFrame):
|
||||
"""Frame indicating that a client has connected to the transport.
|
||||
|
||||
Pushed downstream by the input transport when a client (participant)
|
||||
connects. Used by observers to measure transport readiness timing.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputTransportReadyFrame(ControlFrame):
|
||||
"""Frame indicating that the output transport is ready.
|
||||
@@ -1990,6 +1993,32 @@ class LLMFullResponseEndFrame(ControlFrame):
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMAssistantPushAggregationFrame(ControlFrame):
|
||||
"""Frame that forces the LLM assistant aggregator to push its current aggregation to context.
|
||||
|
||||
When received by ``LLMAssistantAggregator``, any text that has been accumulated
|
||||
in the aggregation buffer is immediately committed to the conversation context as
|
||||
an assistant message, without waiting for an ``LLMFullResponseEndFrame``.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSummarizeContextFrame(ControlFrame):
|
||||
"""Frame requesting on-demand context summarization.
|
||||
|
||||
Push this frame into the pipeline to trigger a manual context summarization.
|
||||
|
||||
Parameters:
|
||||
config: Optional per-request override for summary generation settings
|
||||
(prompt, token budget, messages to keep). If ``None``, the
|
||||
summarizer's default :class:`~pipecat.utils.context.llm_context_summarization.LLMContextSummaryConfig`
|
||||
is used.
|
||||
"""
|
||||
|
||||
config: Optional["LLMContextSummaryConfig"] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextSummaryRequestFrame(ControlFrame):
|
||||
"""Frame requesting context summarization from an LLM service.
|
||||
@@ -2009,6 +2038,8 @@ class LLMContextSummaryRequestFrame(ControlFrame):
|
||||
the summary text.
|
||||
summarization_prompt: System prompt instructing the LLM how to generate
|
||||
the summary.
|
||||
summarization_timeout: Maximum time in seconds for the LLM to generate a
|
||||
summary. When None, a default timeout of 120s is applied.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
@@ -2016,6 +2047,7 @@ class LLMContextSummaryRequestFrame(ControlFrame):
|
||||
min_messages_to_keep: int
|
||||
target_context_tokens: int
|
||||
summarization_prompt: str
|
||||
summarization_timeout: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -100,3 +100,11 @@ class BaseObserver(BaseObject):
|
||||
data: The event data containing details about the frame transfer.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_pipeline_started(self):
|
||||
"""Called when the pipeline has fully started.
|
||||
|
||||
Fired after the ``StartFrame`` has been processed by all processors
|
||||
in the pipeline, including nested ``ParallelPipeline`` branches.
|
||||
"""
|
||||
pass
|
||||
|
||||
328
src/pipecat/observers/startup_timing_observer.py
Normal file
328
src/pipecat/observers/startup_timing_observer.py
Normal file
@@ -0,0 +1,328 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Observer for tracking pipeline startup timing.
|
||||
|
||||
This module provides an observer that measures how long each processor's
|
||||
``start()`` method takes during pipeline startup. It works by tracking
|
||||
when a ``StartFrame`` arrives at a processor (``on_process_frame``) versus
|
||||
when it leaves (``on_push_frame``), giving the exact ``start()`` duration
|
||||
for each processor in the pipeline.
|
||||
|
||||
It also measures transport timing — the time from ``StartFrame`` to the
|
||||
first ``BotConnectedFrame`` (SFU transports only) and ``ClientConnectedFrame``
|
||||
— via a separate ``on_transport_timing_report`` event.
|
||||
|
||||
Example::
|
||||
|
||||
observer = StartupTimingObserver()
|
||||
|
||||
@observer.event_handler("on_startup_timing_report")
|
||||
async def on_report(observer, report):
|
||||
for t in report.processor_timings:
|
||||
print(f"{t.processor_name}: {t.duration_secs:.3f}s")
|
||||
|
||||
@observer.event_handler("on_transport_timing_report")
|
||||
async def on_transport(observer, report):
|
||||
if report.bot_connected_secs is not None:
|
||||
print(f"Bot connected in {report.bot_connected_secs:.3f}s")
|
||||
print(f"Client connected in {report.client_connected_secs:.3f}s")
|
||||
|
||||
task = PipelineTask(pipeline, observers=[observer])
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import BotConnectedFrame, ClientConnectedFrame, StartFrame
|
||||
from pipecat.observers.base_observer import BaseObserver, FrameProcessed, FramePushed
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.pipeline import PipelineSource
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
# Internal pipeline types excluded from tracking by default.
|
||||
_INTERNAL_TYPES = (PipelineSource, BasePipeline)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ArrivalInfo:
|
||||
"""Internal record of when a StartFrame arrived at a processor."""
|
||||
|
||||
processor: FrameProcessor
|
||||
arrival_ts_ns: int
|
||||
|
||||
|
||||
class ProcessorStartupTiming(BaseModel):
|
||||
"""Startup timing for a single processor.
|
||||
|
||||
Parameters:
|
||||
processor_name: The name of the processor.
|
||||
start_offset_secs: Offset in seconds from the StartFrame to when this
|
||||
processor's start() began.
|
||||
duration_secs: How long the processor's start() took, in seconds.
|
||||
"""
|
||||
|
||||
processor_name: str
|
||||
start_offset_secs: float
|
||||
duration_secs: float
|
||||
|
||||
|
||||
class StartupTimingReport(BaseModel):
|
||||
"""Report of startup timings for all measured processors.
|
||||
|
||||
Parameters:
|
||||
start_time: Unix timestamp when the first processor began starting.
|
||||
total_duration_secs: Total wall-clock time from first to last processor start.
|
||||
processor_timings: Per-processor timing data, in pipeline order.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
total_duration_secs: float
|
||||
processor_timings: List[ProcessorStartupTiming] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TransportTimingReport(BaseModel):
|
||||
"""Time from pipeline start to transport connection milestones.
|
||||
|
||||
Parameters:
|
||||
start_time: Unix timestamp of the StartFrame (pipeline start).
|
||||
bot_connected_secs: Seconds from StartFrame to first BotConnectedFrame
|
||||
(only set for SFU transports).
|
||||
client_connected_secs: Seconds from StartFrame to first ClientConnectedFrame.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
bot_connected_secs: Optional[float] = None
|
||||
client_connected_secs: Optional[float] = None
|
||||
|
||||
|
||||
class StartupTimingObserver(BaseObserver):
|
||||
"""Observer that measures processor startup times during pipeline initialization.
|
||||
|
||||
Tracks how long each processor's ``start()`` method takes by measuring the
|
||||
time between when a ``StartFrame`` arrives at a processor and when it is
|
||||
pushed downstream. This captures WebSocket connections, API authentication,
|
||||
model loading, and other initialization work.
|
||||
|
||||
Also measures transport timing, the time from ``StartFrame`` to connection
|
||||
milestones:
|
||||
|
||||
- ``bot_connected_secs``: When the bot joins the transport room
|
||||
(SFU transports only, triggered by ``BotConnectedFrame``).
|
||||
- ``client_connected_secs``: When a remote participant connects
|
||||
(triggered by ``ClientConnectedFrame``).
|
||||
|
||||
By default, internal pipeline processors (``PipelineSource``, ``Pipeline``)
|
||||
are excluded from the report. Pass ``processor_types`` to measure only
|
||||
specific types.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_startup_timing_report: Called once after startup completes with the full
|
||||
timing report.
|
||||
- on_transport_timing_report: Called once when the first client connects with a
|
||||
TransportTimingReport containing client_connected_secs and bot_connected_secs
|
||||
(if available).
|
||||
|
||||
Example::
|
||||
|
||||
observer = StartupTimingObserver(
|
||||
processor_types=(STTService, TTSService)
|
||||
)
|
||||
|
||||
@observer.event_handler("on_startup_timing_report")
|
||||
async def on_report(observer, report):
|
||||
for t in report.processor_timings:
|
||||
logger.info(f"{t.processor_name}: {t.duration_secs:.3f}s")
|
||||
|
||||
@observer.event_handler("on_transport_timing_report")
|
||||
async def on_transport(observer, report):
|
||||
if report.bot_connected_secs is not None:
|
||||
logger.info(f"Bot connected in {report.bot_connected_secs:.3f}s")
|
||||
logger.info(f"Client connected in {report.client_connected_secs:.3f}s")
|
||||
|
||||
task = PipelineTask(pipeline, observers=[observer])
|
||||
|
||||
Args:
|
||||
processor_types: Optional tuple of processor types to measure. If None,
|
||||
all non-internal processors are measured.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
processor_types: Optional[Tuple[Type[FrameProcessor], ...]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the startup timing observer.
|
||||
|
||||
Args:
|
||||
processor_types: Optional tuple of processor types to measure.
|
||||
If None, all non-internal processors are measured.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._processor_types = processor_types
|
||||
|
||||
# Map processor ID -> arrival info.
|
||||
self._arrivals: Dict[int, _ArrivalInfo] = {}
|
||||
|
||||
# Collected timings in pipeline order.
|
||||
self._timings: List[ProcessorStartupTiming] = []
|
||||
|
||||
# Lock onto the first StartFrame we see (by frame ID).
|
||||
self._start_frame_id: Optional[str] = None
|
||||
|
||||
# Whether we've already emitted the startup timing report.
|
||||
self._startup_timing_reported = False
|
||||
|
||||
# Whether we've already measured transport timing.
|
||||
self._transport_timing_reported = False
|
||||
|
||||
# Timestamp (ns) when we first see a StartFrame arrive at a processor.
|
||||
self._start_frame_arrival_ns: Optional[int] = None
|
||||
|
||||
# Bot connected timing (stored for inclusion in the transport report).
|
||||
self._bot_connected_secs: Optional[float] = None
|
||||
|
||||
# Wall clock time when the StartFrame was first seen.
|
||||
self._start_wall_clock: Optional[float] = None
|
||||
|
||||
self._register_event_handler("on_startup_timing_report")
|
||||
self._register_event_handler("on_transport_timing_report")
|
||||
|
||||
def _should_track(self, processor: FrameProcessor) -> bool:
|
||||
"""Check if a processor should be tracked for timing.
|
||||
|
||||
Args:
|
||||
processor: The processor to check.
|
||||
|
||||
Returns:
|
||||
True if the processor matches the filter or no filter is set.
|
||||
"""
|
||||
if self._processor_types is not None:
|
||||
return isinstance(processor, self._processor_types)
|
||||
# Default: exclude internal pipeline plumbing.
|
||||
return not isinstance(processor, _INTERNAL_TYPES)
|
||||
|
||||
async def on_pipeline_started(self):
|
||||
"""Emit the startup timing report when the pipeline has fully started.
|
||||
|
||||
Called by the ``PipelineTask`` after the ``StartFrame`` has been
|
||||
processed by all processors, including nested ``ParallelPipeline``
|
||||
branches.
|
||||
"""
|
||||
if self._timings:
|
||||
await self._emit_report()
|
||||
|
||||
async def on_process_frame(self, data: FrameProcessed):
|
||||
"""Record when a StartFrame arrives at a processor.
|
||||
|
||||
Args:
|
||||
data: The frame processing event data.
|
||||
"""
|
||||
if self._startup_timing_reported:
|
||||
return
|
||||
|
||||
if not isinstance(data.frame, StartFrame):
|
||||
return
|
||||
|
||||
# Lock onto the first StartFrame.
|
||||
if self._start_frame_id is None:
|
||||
self._start_frame_id = data.frame.id
|
||||
self._start_frame_arrival_ns = data.timestamp
|
||||
self._start_wall_clock = time.time()
|
||||
elif data.frame.id != self._start_frame_id:
|
||||
return
|
||||
|
||||
if self._should_track(data.processor):
|
||||
self._arrivals[data.processor.id] = _ArrivalInfo(
|
||||
processor=data.processor, arrival_ts_ns=data.timestamp
|
||||
)
|
||||
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
"""Record when a StartFrame leaves a processor and compute the delta.
|
||||
|
||||
Also handles ``BotConnectedFrame`` and ``ClientConnectedFrame`` to
|
||||
measure transport timing.
|
||||
|
||||
Args:
|
||||
data: The frame push event data.
|
||||
"""
|
||||
if isinstance(data.frame, BotConnectedFrame):
|
||||
self._handle_bot_connected(data)
|
||||
return
|
||||
|
||||
if isinstance(data.frame, ClientConnectedFrame):
|
||||
await self._handle_client_connected(data)
|
||||
return
|
||||
|
||||
if self._startup_timing_reported:
|
||||
return
|
||||
|
||||
if not isinstance(data.frame, StartFrame):
|
||||
return
|
||||
|
||||
if self._start_frame_id is not None and data.frame.id != self._start_frame_id:
|
||||
return
|
||||
|
||||
arrival = self._arrivals.pop(data.source.id, None)
|
||||
if arrival is None:
|
||||
return
|
||||
|
||||
duration_ns = data.timestamp - arrival.arrival_ts_ns
|
||||
duration_secs = duration_ns / 1e9
|
||||
start_offset_secs = (arrival.arrival_ts_ns - self._start_frame_arrival_ns) / 1e9
|
||||
|
||||
self._timings.append(
|
||||
ProcessorStartupTiming(
|
||||
processor_name=arrival.processor.name,
|
||||
start_offset_secs=start_offset_secs,
|
||||
duration_secs=duration_secs,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_bot_connected(self, data: FramePushed):
|
||||
"""Record bot connected timing on first BotConnectedFrame."""
|
||||
if self._bot_connected_secs is not None or self._start_frame_arrival_ns is None:
|
||||
return
|
||||
|
||||
delta_ns = data.timestamp - self._start_frame_arrival_ns
|
||||
self._bot_connected_secs = delta_ns / 1e9
|
||||
|
||||
async def _handle_client_connected(self, data: FramePushed):
|
||||
"""Emit transport timing report on first ClientConnectedFrame."""
|
||||
if self._transport_timing_reported or self._start_frame_arrival_ns is None:
|
||||
return
|
||||
|
||||
self._transport_timing_reported = True
|
||||
delta_ns = data.timestamp - self._start_frame_arrival_ns
|
||||
client_connected_secs = delta_ns / 1e9
|
||||
report = TransportTimingReport(
|
||||
start_time=self._start_wall_clock or 0.0,
|
||||
bot_connected_secs=self._bot_connected_secs,
|
||||
client_connected_secs=client_connected_secs,
|
||||
)
|
||||
await self._call_event_handler("on_transport_timing_report", report)
|
||||
|
||||
async def _emit_report(self):
|
||||
"""Build and emit the startup timing report."""
|
||||
if self._startup_timing_reported:
|
||||
return
|
||||
self._startup_timing_reported = True
|
||||
|
||||
total = sum(t.duration_secs for t in self._timings)
|
||||
|
||||
report = StartupTimingReport(
|
||||
start_time=self._start_wall_clock or 0.0,
|
||||
total_duration_secs=total,
|
||||
processor_timings=self._timings,
|
||||
)
|
||||
|
||||
await self._call_event_handler("on_startup_timing_report", report)
|
||||
@@ -1,22 +1,146 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Observer for tracking user-to-bot response latency.
|
||||
|
||||
This module provides an observer that monitors the time between when a user
|
||||
stops speaking and when the bot starts speaking, emitting events when latency
|
||||
is measured.
|
||||
is measured. Optionally collects per-service latency breakdown metrics
|
||||
(TTFB, text aggregation) when ``enable_metrics=True``.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Set
|
||||
from collections import deque
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
ClientConnectedFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
MetricsFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
TextAggregationMetricsData,
|
||||
TTFBMetricsData,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class TTFBBreakdownMetrics(BaseModel):
|
||||
"""TTFB measurement with timestamp for timeline placement.
|
||||
|
||||
Parameters:
|
||||
processor: Name of the processor that reported the TTFB.
|
||||
model: Optional model name associated with the metric.
|
||||
start_time: Unix timestamp when the TTFB measurement started.
|
||||
duration_secs: TTFB duration in seconds.
|
||||
"""
|
||||
|
||||
processor: str
|
||||
model: Optional[str] = None
|
||||
start_time: float
|
||||
duration_secs: float
|
||||
|
||||
|
||||
class TextAggregationBreakdownMetrics(BaseModel):
|
||||
"""Text aggregation measurement with timestamp for timeline placement.
|
||||
|
||||
Parameters:
|
||||
processor: Name of the processor that reported the metric.
|
||||
start_time: Unix timestamp when text aggregation started.
|
||||
duration_secs: Aggregation duration in seconds.
|
||||
"""
|
||||
|
||||
processor: str
|
||||
start_time: float
|
||||
duration_secs: float
|
||||
|
||||
|
||||
class FunctionCallMetrics(BaseModel):
|
||||
"""Latency for a single function call execution.
|
||||
|
||||
Parameters:
|
||||
function_name: Name of the function that was called.
|
||||
start_time: Unix timestamp when execution started.
|
||||
duration_secs: Time in seconds from execution start to result.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
start_time: float
|
||||
duration_secs: float
|
||||
|
||||
|
||||
class LatencyBreakdown(BaseModel):
|
||||
"""Per-service latency breakdown for a single user-to-bot cycle.
|
||||
|
||||
Collected between ``VADUserStoppedSpeakingFrame`` and
|
||||
``BotStartedSpeakingFrame`` when ``enable_metrics=True`` in
|
||||
:class:`~pipecat.pipeline.task.PipelineParams`.
|
||||
|
||||
Parameters:
|
||||
ttfb: Time-to-first-byte metrics from each service in the pipeline.
|
||||
text_aggregation: First text aggregation measurement, representing
|
||||
the latency cost of sentence aggregation in the TTS pipeline.
|
||||
user_turn_start_time: Unix timestamp when the user turn started
|
||||
(actual user silence, adjusted for VAD stop_secs). ``None`` if
|
||||
no ``VADUserStoppedSpeakingFrame`` was observed.
|
||||
user_turn_secs: Duration in seconds of the user's turn, measured
|
||||
from when the user actually stopped speaking to when the turn
|
||||
was released (``UserStoppedSpeakingFrame``). This includes
|
||||
VAD silence detection, STT finalization, and any turn analyzer
|
||||
wait. ``None`` if no ``UserStoppedSpeakingFrame`` was observed
|
||||
(e.g. no turn analyzer configured).
|
||||
function_calls: Latency for each function call executed during
|
||||
this cycle. Empty if no function calls occurred.
|
||||
"""
|
||||
|
||||
ttfb: List[TTFBBreakdownMetrics] = Field(default_factory=list)
|
||||
text_aggregation: Optional[TextAggregationBreakdownMetrics] = None
|
||||
user_turn_start_time: Optional[float] = None
|
||||
user_turn_secs: Optional[float] = None
|
||||
function_calls: List[FunctionCallMetrics] = Field(default_factory=list)
|
||||
|
||||
def chronological_events(self) -> List[str]:
|
||||
"""Return human-readable event labels sorted by start time.
|
||||
|
||||
Collects all sub-metrics into a flat list, sorts by ``start_time``,
|
||||
and returns formatted strings suitable for logging.
|
||||
|
||||
Returns:
|
||||
List of formatted strings, one per event, in chronological order.
|
||||
"""
|
||||
events: List[tuple] = []
|
||||
|
||||
if self.user_turn_start_time is not None and self.user_turn_secs is not None:
|
||||
events.append((self.user_turn_start_time, f"User turn: {self.user_turn_secs:.3f}s"))
|
||||
|
||||
for t in self.ttfb:
|
||||
events.append((t.start_time, f"{t.processor}: TTFB {t.duration_secs:.3f}s"))
|
||||
|
||||
for fc in self.function_calls:
|
||||
events.append((fc.start_time, f"{fc.function_name}: {fc.duration_secs:.3f}s"))
|
||||
|
||||
if self.text_aggregation:
|
||||
ta = self.text_aggregation
|
||||
events.append(
|
||||
(ta.start_time, f"{ta.processor}: text aggregation {ta.duration_secs:.3f}s")
|
||||
)
|
||||
|
||||
events.sort(key=lambda e: e[0])
|
||||
return [label for _, label in events]
|
||||
|
||||
|
||||
class UserBotLatencyObserver(BaseObserver):
|
||||
"""Observer that tracks user-to-bot response latency.
|
||||
|
||||
@@ -25,34 +149,66 @@ class UserBotLatencyObserver(BaseObserver):
|
||||
latency is measured, allowing consumers to log, trace, or otherwise process
|
||||
the latency data.
|
||||
|
||||
When ``enable_metrics=True`` in pipeline params, also collects per-service
|
||||
latency breakdown (TTFB, text aggregation) and emits an
|
||||
``on_latency_breakdown`` event alongside the existing latency measurement.
|
||||
|
||||
This observer follows the composition pattern used by TurnTrackingObserver,
|
||||
acting as a reusable component for latency measurement.
|
||||
|
||||
Events:
|
||||
on_latency_measured(observer, latency_seconds): Emitted when user-to-bot
|
||||
latency is calculated. Includes the latency value in seconds as a float.
|
||||
on_latency_measured(observer, latency_seconds): Emitted when
|
||||
time-to-first-bot-speech is calculated. Measures the time from
|
||||
when the user stopped speaking to when the bot starts speaking.
|
||||
on_latency_breakdown(observer, breakdown): Emitted at each
|
||||
``BotStartedSpeakingFrame`` with a :class:`LatencyBreakdown`
|
||||
containing per-service metrics collected during the user→bot cycle.
|
||||
on_first_bot_speech_latency(observer, latency_seconds): Emitted once,
|
||||
the first time ``BotStartedSpeakingFrame`` arrives after
|
||||
``ClientConnectedFrame``. Measures the time from client connection
|
||||
to the first bot speech.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, *, max_frames=100, **kwargs):
|
||||
"""Initialize the user-bot latency observer.
|
||||
|
||||
Sets up tracking for processed frames and user speech timing
|
||||
to calculate response latencies.
|
||||
|
||||
Args:
|
||||
max_frames: Maximum number of frame IDs to keep in history for
|
||||
duplicate detection. Defaults to 100.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._user_stopped_time: Optional[float] = None
|
||||
self._processed_frames: Set[str] = set()
|
||||
self._user_turn_start_time: Optional[float] = None
|
||||
self._user_turn: Optional[float] = None
|
||||
|
||||
# First bot speech tracking
|
||||
self._client_connected_time: Optional[float] = None
|
||||
self._first_bot_speech_measured: bool = False
|
||||
|
||||
# Frame deduplication (bounded deque + set pattern)
|
||||
self._processed_frames: set = set()
|
||||
self._frame_history: deque = deque(maxlen=max_frames)
|
||||
|
||||
# Per-cycle metric accumulators
|
||||
self._ttfb: List[TTFBBreakdownMetrics] = []
|
||||
self._text_aggregation: Optional[TextAggregationBreakdownMetrics] = None
|
||||
self._function_call_starts: Dict[str, tuple[str, float]] = {}
|
||||
self._function_call_metrics: List[FunctionCallMetrics] = []
|
||||
|
||||
self._register_event_handler("on_latency_measured")
|
||||
self._register_event_handler("on_latency_breakdown")
|
||||
self._register_event_handler("on_first_bot_speech_latency")
|
||||
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
"""Process frames to track speech timing and calculate latency.
|
||||
|
||||
Tracks VAD events and bot speaking events to measure the time between
|
||||
user stopping speech and bot starting speech.
|
||||
user stopping speech and bot starting speech. Also accumulates metrics
|
||||
from MetricsFrame for the latency breakdown.
|
||||
|
||||
Args:
|
||||
data: Frame push event containing the frame and direction information.
|
||||
@@ -61,23 +217,135 @@ class UserBotLatencyObserver(BaseObserver):
|
||||
if data.direction != FrameDirection.DOWNSTREAM:
|
||||
return
|
||||
|
||||
# Skip already processed frames
|
||||
# Skip already processed frames (bounded deque + set)
|
||||
if data.frame.id in self._processed_frames:
|
||||
return
|
||||
|
||||
self._processed_frames.add(data.frame.id)
|
||||
self._frame_history.append(data.frame.id)
|
||||
|
||||
# Track VAD and bot speaking events for latency
|
||||
if len(self._processed_frames) > len(self._frame_history):
|
||||
self._processed_frames = set(self._frame_history)
|
||||
|
||||
# Track client connection (first occurrence only)
|
||||
if isinstance(data.frame, ClientConnectedFrame):
|
||||
if self._client_connected_time is None:
|
||||
self._client_connected_time = time.time()
|
||||
return
|
||||
|
||||
# Track speech and pipeline events for latency
|
||||
if isinstance(data.frame, VADUserStartedSpeakingFrame):
|
||||
# Reset when user starts speaking
|
||||
self._user_stopped_time = None
|
||||
self._user_turn_start_time = None
|
||||
self._user_turn = None
|
||||
self._reset_accumulators()
|
||||
# If user speaks before the bot's first speech, abandon the
|
||||
# first-bot-speech measurement — it's only meaningful for greetings.
|
||||
self._first_bot_speech_measured = True
|
||||
elif isinstance(data.frame, VADUserStoppedSpeakingFrame):
|
||||
# Record the actual time the user stopped speaking, which is
|
||||
# the VAD determination time minus the stop_secs silence duration
|
||||
# that had to elapse before the VAD confirmed speech ended.
|
||||
self._user_stopped_time = data.frame.timestamp - data.frame.stop_secs
|
||||
elif isinstance(data.frame, BotStartedSpeakingFrame) and self._user_stopped_time:
|
||||
# Calculate and emit latency
|
||||
self._user_turn_start_time = self._user_stopped_time
|
||||
elif isinstance(data.frame, UserStoppedSpeakingFrame):
|
||||
# Measure the user turn duration: from actual user silence to
|
||||
# turn release. Includes VAD silence detection, STT finalization,
|
||||
# and any turn analyzer wait.
|
||||
if self._user_stopped_time is not None:
|
||||
self._user_turn = time.time() - self._user_stopped_time
|
||||
elif isinstance(data.frame, InterruptionFrame):
|
||||
# Discard stale metrics from cancelled LLM/TTS cycles
|
||||
self._reset_accumulators()
|
||||
elif isinstance(data.frame, FunctionCallInProgressFrame):
|
||||
self._function_call_starts[data.frame.tool_call_id] = (
|
||||
data.frame.function_name,
|
||||
time.time(),
|
||||
)
|
||||
elif isinstance(data.frame, FunctionCallResultFrame):
|
||||
start = self._function_call_starts.pop(data.frame.tool_call_id, None)
|
||||
if start is not None:
|
||||
function_name, start_time = start
|
||||
self._function_call_metrics.append(
|
||||
FunctionCallMetrics(
|
||||
function_name=function_name,
|
||||
start_time=start_time,
|
||||
duration_secs=time.time() - start_time,
|
||||
)
|
||||
)
|
||||
elif isinstance(data.frame, MetricsFrame):
|
||||
self._handle_metrics_frame(data.frame)
|
||||
elif isinstance(data.frame, BotStartedSpeakingFrame):
|
||||
await self._handle_bot_started_speaking()
|
||||
|
||||
async def _handle_bot_started_speaking(self):
|
||||
"""Handle BotStartedSpeakingFrame to emit latency and breakdown."""
|
||||
emit_breakdown = False
|
||||
|
||||
# One-time first bot speech measurement (client connect → first speech)
|
||||
if self._client_connected_time is not None and not self._first_bot_speech_measured:
|
||||
self._first_bot_speech_measured = True
|
||||
latency = time.time() - self._client_connected_time
|
||||
await self._call_event_handler("on_first_bot_speech_latency", latency)
|
||||
emit_breakdown = True
|
||||
|
||||
if self._user_stopped_time is not None:
|
||||
latency = time.time() - self._user_stopped_time
|
||||
self._user_stopped_time = None
|
||||
await self._call_event_handler("on_latency_measured", latency)
|
||||
emit_breakdown = True
|
||||
|
||||
if emit_breakdown:
|
||||
breakdown = LatencyBreakdown(
|
||||
ttfb=list(self._ttfb),
|
||||
text_aggregation=self._text_aggregation,
|
||||
user_turn_start_time=self._user_turn_start_time,
|
||||
user_turn_secs=self._user_turn,
|
||||
function_calls=list(self._function_call_metrics),
|
||||
)
|
||||
await self._call_event_handler("on_latency_breakdown", breakdown)
|
||||
self._reset_accumulators()
|
||||
|
||||
def _handle_metrics_frame(self, frame: MetricsFrame):
|
||||
"""Extract latency metrics from a MetricsFrame.
|
||||
|
||||
Accumulates metrics when a measurement is in progress: either a
|
||||
user→bot cycle (after ``VADUserStoppedSpeakingFrame``) or the
|
||||
first-bot-speech window (after ``ClientConnectedFrame``).
|
||||
"""
|
||||
waiting_for_first_speech = (
|
||||
self._client_connected_time is not None and not self._first_bot_speech_measured
|
||||
)
|
||||
if self._user_stopped_time is None and not waiting_for_first_speech:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
for metrics_data in frame.data:
|
||||
if isinstance(metrics_data, TTFBMetricsData) and metrics_data.value > 0:
|
||||
self._ttfb.append(
|
||||
TTFBBreakdownMetrics(
|
||||
processor=metrics_data.processor,
|
||||
model=metrics_data.model,
|
||||
start_time=now - metrics_data.value,
|
||||
duration_secs=metrics_data.value,
|
||||
)
|
||||
)
|
||||
elif isinstance(metrics_data, TextAggregationMetricsData):
|
||||
# Only keep the first measurement — it's the one that
|
||||
# impacts the initial speaking latency.
|
||||
if self._text_aggregation is None:
|
||||
self._text_aggregation = TextAggregationBreakdownMetrics(
|
||||
processor=metrics_data.processor,
|
||||
start_time=now - metrics_data.value,
|
||||
duration_secs=metrics_data.value,
|
||||
)
|
||||
|
||||
def _reset_accumulators(self):
|
||||
"""Clear per-cycle metric accumulators."""
|
||||
self._ttfb = []
|
||||
self._text_aggregation = None
|
||||
self._user_turn_start_time = None
|
||||
self._user_turn = None
|
||||
self._function_call_starts = {}
|
||||
self._function_call_metrics = []
|
||||
|
||||
@@ -330,6 +330,7 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
# RTVI support
|
||||
self._rtvi = None
|
||||
prepend_rtvi = False
|
||||
external_rtvi = self._find_processor(pipeline, RTVIProcessor)
|
||||
external_observer_found = any(isinstance(o, RTVIObserver) for o in observers)
|
||||
|
||||
@@ -352,6 +353,7 @@ class PipelineTask(BasePipelineTask):
|
||||
elif enable_rtvi:
|
||||
self._rtvi = rtvi_processor or RTVIProcessor()
|
||||
observers.append(self._rtvi.create_rtvi_observer(params=rtvi_observer_params))
|
||||
prepend_rtvi = True
|
||||
|
||||
if self._rtvi:
|
||||
# Automatically call RTVIProcessor.set_bot_ready()
|
||||
@@ -387,9 +389,12 @@ class PipelineTask(BasePipelineTask):
|
||||
# source allows us to receive and react to upstream frames, and the sink
|
||||
# allows us to receive and react to downstream frames.
|
||||
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
|
||||
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
processors = [self._rtvi, pipeline] if self._rtvi else [pipeline]
|
||||
self._pipeline = Pipeline(processors, source=source, sink=sink)
|
||||
self._sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
# Only prepend the RTVIProcessor if we created it ourselves. When the
|
||||
# user already placed it inside their pipeline we must not insert it
|
||||
# again or it will appear twice in the frame chain.
|
||||
processors = [self._rtvi, pipeline] if prepend_rtvi else [pipeline]
|
||||
self._pipeline = Pipeline(processors, source=source, sink=self._sink)
|
||||
|
||||
# The task observer acts as a proxy to the provided observers. This way,
|
||||
# we only need to pass a single observer (using the StartFrame) which
|
||||
@@ -620,26 +625,43 @@ class PipelineTask(BasePipelineTask):
|
||||
self._finished = True
|
||||
logger.debug(f"Pipeline task {self} has finished")
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
"""Queue a single frame to be pushed down the pipeline.
|
||||
async def queue_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Queue a single frame to be pushed through the pipeline.
|
||||
|
||||
Downstream frames are pushed from the beginning of the pipeline.
|
||||
Upstream frames are pushed from the end of the pipeline.
|
||||
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
direction: The direction to push the frame. Defaults to downstream.
|
||||
"""
|
||||
await self._push_queue.put(frame)
|
||||
if direction == FrameDirection.DOWNSTREAM:
|
||||
await self._push_queue.put(frame)
|
||||
else:
|
||||
await self._sink.queue_frame(frame, direction)
|
||||
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
"""Queues multiple frames to be pushed down the pipeline.
|
||||
async def queue_frames(
|
||||
self,
|
||||
frames: Iterable[Frame] | AsyncIterable[Frame],
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
):
|
||||
"""Queue multiple frames to be pushed through the pipeline.
|
||||
|
||||
Downstream frames are pushed from the beginning of the pipeline.
|
||||
Upstream frames are pushed from the end of the pipeline.
|
||||
|
||||
Args:
|
||||
frames: An iterable or async iterable of frames to be processed.
|
||||
direction: The direction to push the frames. Defaults to downstream.
|
||||
"""
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
await self.queue_frame(frame, direction)
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
await self.queue_frame(frame)
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def _cancel(self, *, reason: Optional[str] = None):
|
||||
"""Internal cancellation logic for the pipeline task.
|
||||
@@ -870,7 +892,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame(event=frame.event))
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
await self._call_event_handler("on_pipeline_error", frame)
|
||||
if frame.fatal:
|
||||
@@ -893,6 +915,7 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._call_event_handler("on_pipeline_started", frame)
|
||||
await self._observer.on_pipeline_started()
|
||||
|
||||
# Start heartbeat tasks now that StartFrame has been processed
|
||||
# by all processors in the pipeline
|
||||
@@ -909,8 +932,6 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._heartbeat_queue.put(frame)
|
||||
|
||||
|
||||
@@ -39,6 +39,12 @@ class Proxy:
|
||||
observer: BaseObserver
|
||||
|
||||
|
||||
class _PipelineStartedSignal:
|
||||
"""Internal sentinel queued to observers when the pipeline has started."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TaskObserver(BaseObserver):
|
||||
"""Proxy observer that manages multiple observers without blocking the pipeline.
|
||||
|
||||
@@ -129,6 +135,10 @@ class TaskObserver(BaseObserver):
|
||||
for proxy in self._proxies:
|
||||
await proxy.cleanup()
|
||||
|
||||
async def on_pipeline_started(self):
|
||||
"""Forward pipeline started signal to all managed observers."""
|
||||
await self._send_to_proxy(_PipelineStartedSignal())
|
||||
|
||||
async def on_process_frame(self, data: FrameProcessed):
|
||||
"""Queue frame data for all managed observers.
|
||||
|
||||
@@ -186,7 +196,9 @@ class TaskObserver(BaseObserver):
|
||||
while True:
|
||||
data = await queue.get()
|
||||
|
||||
if isinstance(data, FramePushed):
|
||||
if isinstance(data, _PipelineStartedSignal):
|
||||
await observer.on_pipeline_started()
|
||||
elif isinstance(data, FramePushed):
|
||||
if on_push_frame_deprecated:
|
||||
await observer.on_push_frame(
|
||||
data.source, data.destination, data.frame, data.direction, data.timestamp
|
||||
|
||||
@@ -104,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
# Check for immediate flush conditions
|
||||
if frame.button == self._termination_digit:
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
|
||||
"""This module defines a summarizer for managing LLM context summarization."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -17,28 +19,68 @@ from pipecat.frames.frames import (
|
||||
LLMContextSummaryRequestFrame,
|
||||
LLMContextSummaryResultFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMSummarizeContextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMContextSummarizationConfig,
|
||||
DEFAULT_SUMMARIZATION_TIMEOUT,
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummarizationUtil,
|
||||
LLMContextSummaryConfig,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummaryAppliedEvent:
|
||||
"""Event data emitted when context summarization completes successfully.
|
||||
|
||||
Parameters:
|
||||
original_message_count: Number of messages before summarization.
|
||||
new_message_count: Number of messages after summarization.
|
||||
summarized_message_count: Number of messages that were compressed
|
||||
into the summary.
|
||||
preserved_message_count: Number of recent messages preserved
|
||||
uncompressed.
|
||||
"""
|
||||
|
||||
original_message_count: int
|
||||
new_message_count: int
|
||||
summarized_message_count: int
|
||||
preserved_message_count: int
|
||||
|
||||
|
||||
class LLMContextSummarizer(BaseObject):
|
||||
"""Summarizer for managing LLM context summarization.
|
||||
|
||||
This class manages automatic context summarization when token or message
|
||||
limits are reached. It monitors the LLM context size, triggers
|
||||
summarization requests, and applies the results to compress conversation history.
|
||||
This class manages context summarization, either automatically when token or
|
||||
message limits are reached, or on-demand when an ``LLMSummarizeContextFrame``
|
||||
is received. It monitors the LLM context size, triggers summarization requests,
|
||||
and applies the results to compress conversation history.
|
||||
|
||||
When ``auto_trigger=True`` (the default), summarization is triggered
|
||||
automatically based on the configured thresholds in
|
||||
``LLMAutoContextSummarizationConfig``. When ``auto_trigger=False``,
|
||||
threshold checks are skipped and summarization only happens when an
|
||||
``LLMSummarizeContextFrame`` is explicitly pushed into the pipeline.
|
||||
|
||||
Both modes can coexist: set ``auto_trigger=True`` and also push
|
||||
``LLMSummarizeContextFrame`` at any time to force an immediate summarization
|
||||
(subject to the ``_summarization_in_progress`` guard).
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_request_summarization: Emitted when summarization should be triggered.
|
||||
The aggregator should broadcast this frame to the LLM service.
|
||||
|
||||
- on_summary_applied: Emitted after a summary has been successfully applied
|
||||
to the context. Receives a SummaryAppliedEvent with metrics about the
|
||||
compression.
|
||||
|
||||
Example::
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
@@ -49,24 +91,36 @@ class LLMContextSummarizer(BaseObject):
|
||||
context=frame.context,
|
||||
...
|
||||
)
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(f"Compressed {event.original_message_count} -> {event.new_message_count} messages")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
context: LLMContext,
|
||||
config: Optional[LLMContextSummarizationConfig] = None,
|
||||
config: Optional[LLMAutoContextSummarizationConfig] = None,
|
||||
auto_trigger: bool = True,
|
||||
):
|
||||
"""Initialize the context summarizer.
|
||||
|
||||
Args:
|
||||
context: The LLM context to monitor and summarize.
|
||||
config: Configuration for summarization behavior. If None, uses default config.
|
||||
config: Auto-summarization configuration controlling both trigger
|
||||
thresholds and default summary generation parameters. If None,
|
||||
uses default ``LLMAutoContextSummarizationConfig`` values.
|
||||
auto_trigger: Whether to automatically trigger summarization when
|
||||
thresholds are reached. When False, summarization only happens
|
||||
when an ``LLMSummarizeContextFrame`` is pushed into the pipeline.
|
||||
Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._context = context
|
||||
self._config = config or LLMContextSummarizationConfig()
|
||||
self._auto_config = config or LLMAutoContextSummarizationConfig()
|
||||
self._auto_trigger = auto_trigger
|
||||
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
|
||||
@@ -74,6 +128,7 @@ class LLMContextSummarizer(BaseObject):
|
||||
self._pending_summary_request_id: Optional[str] = None
|
||||
|
||||
self._register_event_handler("on_request_summarization", sync=True)
|
||||
self._register_event_handler("on_summary_applied")
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BaseTaskManager:
|
||||
@@ -103,6 +158,8 @@ class LLMContextSummarizer(BaseObject):
|
||||
"""
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_response_start(frame)
|
||||
elif isinstance(frame, LLMSummarizeContextFrame):
|
||||
await self._handle_manual_summarization_request(frame)
|
||||
elif isinstance(frame, LLMContextSummaryResultFrame):
|
||||
await self._handle_summary_result(frame)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
@@ -117,12 +174,24 @@ class LLMContextSummarizer(BaseObject):
|
||||
if self._should_summarize():
|
||||
await self._request_summarization()
|
||||
|
||||
async def _handle_interruption(self):
|
||||
"""Handle interruption by canceling summarization in progress.
|
||||
async def _handle_manual_summarization_request(self, frame: LLMSummarizeContextFrame):
|
||||
"""Handle an explicit on-demand summarization request.
|
||||
|
||||
Reuses the same ``_request_summarization()`` code path as auto mode,
|
||||
so bookkeeping (``_summarization_in_progress``,
|
||||
``_pending_summary_request_id``) is always updated correctly.
|
||||
|
||||
Args:
|
||||
frame: The interruption frame.
|
||||
frame: The manual summarization request frame, optionally carrying
|
||||
a per-request :class:`~pipecat.utils.context.llm_context_summarization.LLMContextSummaryConfig`.
|
||||
"""
|
||||
if self._summarization_in_progress:
|
||||
logger.debug(f"{self}: Summarization already in progress, ignoring manual request")
|
||||
return
|
||||
await self._request_summarization(config_override=frame.config)
|
||||
|
||||
async def _handle_interruption(self):
|
||||
"""Handle interruption by canceling summarization in progress."""
|
||||
# Reset summarization state to allow new requests. This is necessary because
|
||||
# the request frame (LLMContextSummaryRequestFrame) may have been cancelled
|
||||
# during interruption. We preserve _pending_summary_request_id to handle the
|
||||
@@ -145,13 +214,17 @@ class LLMContextSummarizer(BaseObject):
|
||||
|
||||
Returns:
|
||||
True if all conditions are met:
|
||||
- ``auto_trigger`` is enabled
|
||||
- No summarization currently in progress
|
||||
- AND either:
|
||||
- Token count exceeds max_context_tokens
|
||||
- OR message count exceeds max_unsummarized_messages since last summary
|
||||
- Token count exceeds ``max_context_tokens``
|
||||
- OR message count exceeds ``max_unsummarized_messages`` since last summary
|
||||
"""
|
||||
logger.trace(f"{self}: Checking if context summarization is needed")
|
||||
|
||||
if not self._auto_trigger:
|
||||
return False
|
||||
|
||||
if self._summarization_in_progress:
|
||||
logger.debug(f"{self}: Summarization already in progress")
|
||||
return False
|
||||
@@ -161,20 +234,20 @@ class LLMContextSummarizer(BaseObject):
|
||||
num_messages = len(self._context.messages)
|
||||
|
||||
# Check if we've reached the token limit
|
||||
token_limit = self._config.max_context_tokens
|
||||
token_limit = self._auto_config.max_context_tokens
|
||||
token_limit_exceeded = total_tokens >= token_limit
|
||||
|
||||
# Check if we've exceeded max unsummarized messages
|
||||
messages_since_summary = len(self._context.messages) - 1
|
||||
message_threshold_exceeded = (
|
||||
messages_since_summary >= self._config.max_unsummarized_messages
|
||||
messages_since_summary >= self._auto_config.max_unsummarized_messages
|
||||
)
|
||||
|
||||
logger.trace(
|
||||
f"{self}: Context has {num_messages} messages, "
|
||||
f"~{total_tokens} tokens (limit: {token_limit}), "
|
||||
f"{messages_since_summary} messages since last summary "
|
||||
f"(message threshold: {self._config.max_unsummarized_messages})"
|
||||
f"(message threshold: {self._auto_config.max_unsummarized_messages})"
|
||||
)
|
||||
|
||||
# Trigger if either limit is exceeded
|
||||
@@ -189,21 +262,30 @@ class LLMContextSummarizer(BaseObject):
|
||||
reason.append(f"~{total_tokens} tokens (>={token_limit} limit)")
|
||||
if message_threshold_exceeded:
|
||||
reason.append(
|
||||
f"{messages_since_summary} messages (>={self._config.max_unsummarized_messages} threshold)"
|
||||
f"{messages_since_summary} messages (>={self._auto_config.max_unsummarized_messages} threshold)"
|
||||
)
|
||||
|
||||
logger.debug(f"{self}: ✓ Summarization needed - {', '.join(reason)}")
|
||||
return True
|
||||
|
||||
async def _request_summarization(self):
|
||||
async def _request_summarization(
|
||||
self, config_override: Optional[LLMContextSummaryConfig] = None
|
||||
):
|
||||
"""Request context summarization from LLM service.
|
||||
|
||||
Creates a summarization request frame and emits it via event handler.
|
||||
Creates a summarization request frame and either handles it directly
|
||||
using a dedicated LLM (if configured) or emits it via event handler
|
||||
for the pipeline's primary LLM.
|
||||
Tracks the request ID to match async responses and prevent race conditions.
|
||||
|
||||
Args:
|
||||
config_override: Optional per-request summary configuration. If provided,
|
||||
overrides the default summary generation settings from
|
||||
``self._auto_config.summary_config``.
|
||||
"""
|
||||
# Generate unique request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
min_keep = self._config.min_messages_after_summary
|
||||
summary_config = config_override or self._auto_config.summary_config
|
||||
|
||||
# Mark summarization in progress
|
||||
self._summarization_in_progress = True
|
||||
@@ -215,13 +297,66 @@ class LLMContextSummarizer(BaseObject):
|
||||
request_frame = LLMContextSummaryRequestFrame(
|
||||
request_id=request_id,
|
||||
context=self._context,
|
||||
min_messages_to_keep=min_keep,
|
||||
target_context_tokens=self._config.target_context_tokens,
|
||||
summarization_prompt=self._config.summary_prompt,
|
||||
min_messages_to_keep=summary_config.min_messages_after_summary,
|
||||
target_context_tokens=summary_config.target_context_tokens,
|
||||
summarization_prompt=summary_config.summary_prompt,
|
||||
summarization_timeout=summary_config.summarization_timeout,
|
||||
)
|
||||
|
||||
# Emit event for aggregator to broadcast
|
||||
await self._call_event_handler("on_request_summarization", request_frame)
|
||||
if summary_config.llm:
|
||||
# Use dedicated LLM directly — no need to involve the pipeline
|
||||
self.task_manager.create_task(
|
||||
self._generate_summary_with_dedicated_llm(summary_config.llm, request_frame),
|
||||
f"{self}-dedicated-llm-summary",
|
||||
)
|
||||
else:
|
||||
# Emit event for aggregator to broadcast to the pipeline LLM
|
||||
await self._call_event_handler("on_request_summarization", request_frame)
|
||||
|
||||
async def _generate_summary_with_dedicated_llm(
|
||||
self, llm: "LLMService", frame: LLMContextSummaryRequestFrame
|
||||
):
|
||||
"""Generate summary using a dedicated LLM service.
|
||||
|
||||
Calls the dedicated LLM's _generate_summary directly and feeds the
|
||||
result back through _handle_summary_result, bypassing the pipeline.
|
||||
|
||||
Args:
|
||||
llm: The dedicated LLM service to use for summarization.
|
||||
frame: The summarization request frame.
|
||||
"""
|
||||
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
|
||||
|
||||
try:
|
||||
summary, last_index = await asyncio.wait_for(
|
||||
llm._generate_summary(frame),
|
||||
timeout=timeout,
|
||||
)
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary=summary,
|
||||
last_summarized_index=last_index,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
error = f"Context summarization timed out after {timeout}s"
|
||||
logger.error(f"{self}: {error}")
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary="",
|
||||
last_summarized_index=-1,
|
||||
error=error,
|
||||
)
|
||||
except Exception as e:
|
||||
error = f"Error generating context summary: {e}"
|
||||
logger.error(f"{self}: {error}")
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary="",
|
||||
last_summarized_index=-1,
|
||||
error=error,
|
||||
)
|
||||
|
||||
await self._handle_summary_result(result_frame)
|
||||
|
||||
async def _handle_summary_result(self, frame: LLMContextSummaryResultFrame):
|
||||
"""Handle context summarization result from LLM service.
|
||||
@@ -234,7 +369,9 @@ class LLMContextSummarizer(BaseObject):
|
||||
"""
|
||||
logger.debug(f"{self}: Received summary result (request_id={frame.request_id})")
|
||||
|
||||
# Check if this is the result we're waiting for
|
||||
# Check if this is the result we're waiting for. Both auto and manual
|
||||
# summarization set _pending_summary_request_id via _request_summarization(),
|
||||
# so this check always applies.
|
||||
if frame.request_id != self._pending_summary_request_id:
|
||||
logger.debug(f"{self}: Ignoring stale summary result (request_id={frame.request_id})")
|
||||
return
|
||||
@@ -271,7 +408,7 @@ class LLMContextSummarizer(BaseObject):
|
||||
if last_summarized_index >= len(self._context.messages):
|
||||
return False
|
||||
|
||||
min_keep = self._config.min_messages_after_summary
|
||||
min_keep = self._auto_config.summary_config.min_messages_after_summary
|
||||
remaining = len(self._context.messages) - 1 - last_summarized_index
|
||||
if remaining < min_keep:
|
||||
return False
|
||||
@@ -288,16 +425,29 @@ class LLMContextSummarizer(BaseObject):
|
||||
summary: The generated summary text.
|
||||
last_summarized_index: Index of the last message that was summarized.
|
||||
"""
|
||||
config = self._auto_config.summary_config
|
||||
messages = self._context.messages
|
||||
|
||||
# Find the first system message to preserve
|
||||
first_system_msg = next((msg for msg in messages if msg.get("role") == "system"), None)
|
||||
# Find the first system message to preserve. LLMSpecificMessage instances are excluded
|
||||
# because they are not dict-like and never represent a system message; they hold
|
||||
# service-specific metadata (e.g. thinking blocks) that is always paired with a
|
||||
# standard message.
|
||||
first_system_msg = next(
|
||||
(
|
||||
msg
|
||||
for msg in messages
|
||||
if not isinstance(msg, LLMSpecificMessage) and msg.get("role") == "system"
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Get recent messages to keep
|
||||
recent_messages = messages[last_summarized_index + 1 :]
|
||||
|
||||
# Create summary message as an assistant message
|
||||
summary_message = {"role": "assistant", "content": f"Conversation summary: {summary}"}
|
||||
# Create summary message as a user message (the summary is context
|
||||
# provided *to* the assistant, not something the assistant said)
|
||||
summary_content = config.summary_message_template.format(summary=summary)
|
||||
summary_message = {"role": "user", "content": summary_content}
|
||||
|
||||
# Reconstruct context
|
||||
new_messages = []
|
||||
@@ -307,9 +457,23 @@ class LLMContextSummarizer(BaseObject):
|
||||
new_messages.extend(recent_messages)
|
||||
|
||||
# Update context
|
||||
original_message_count = len(messages)
|
||||
num_system_preserved = 1 if first_system_msg else 0
|
||||
self._context.set_messages(new_messages)
|
||||
|
||||
# Messages actually summarized = index range minus the preserved system message
|
||||
summarized_count = last_summarized_index + 1 - num_system_preserved
|
||||
|
||||
logger.info(
|
||||
f"{self}: Applied context summary, compressed {last_summarized_index + 1} messages "
|
||||
f"into summary. Context now has {len(new_messages)} messages (was {len(messages)})"
|
||||
f"{self}: Applied context summary, compressed {summarized_count} messages "
|
||||
f"into summary. Context now has {len(new_messages)} messages (was {original_message_count})"
|
||||
)
|
||||
|
||||
# Emit event for observability
|
||||
event = SummaryAppliedEvent(
|
||||
original_message_count=original_message_count,
|
||||
new_message_count=len(new_messages),
|
||||
summarized_message_count=summarized_count,
|
||||
preserved_message_count=len(recent_messages) + num_system_preserved,
|
||||
)
|
||||
await self._call_event_handler("on_summary_applied", event)
|
||||
|
||||
@@ -581,7 +581,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
|
||||
@@ -35,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMAssistantPushAggregationFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMContextSummaryRequestFrame,
|
||||
@@ -78,7 +79,10 @@ from pipecat.turns.user_stop import BaseUserTurnStopStrategy, UserTurnStoppedPar
|
||||
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionConfig
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummarizationConfig,
|
||||
)
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -124,18 +128,54 @@ class LLMAssistantAggregatorParams:
|
||||
in text frames by adding spaces between tokens. This parameter is
|
||||
ignored when used with the newer LLMAssistantAggregator, which
|
||||
handles word spacing automatically.
|
||||
enable_context_summarization: Enable automatic context summarization when token
|
||||
limits are reached (disabled by default). When enabled, older conversation
|
||||
messages are automatically compressed into summaries to manage context size.
|
||||
context_summarization_config: Configuration for context summarization behavior.
|
||||
Controls thresholds, message preservation, and summarization prompts. If None
|
||||
and summarization is enabled, uses default configuration values.
|
||||
enable_auto_context_summarization: Enable automatic context summarization when token
|
||||
or message-count limits are reached (disabled by default). When enabled,
|
||||
older conversation messages are automatically compressed into summaries to
|
||||
manage context size.
|
||||
auto_context_summarization_config: Configuration for automatic context
|
||||
summarization. Controls trigger thresholds, message preservation, and
|
||||
summarization prompts. If None, uses default
|
||||
``LLMAutoContextSummarizationConfig`` values.
|
||||
"""
|
||||
|
||||
expect_stripped_words: bool = True
|
||||
enable_context_summarization: bool = False
|
||||
enable_auto_context_summarization: bool = False
|
||||
auto_context_summarization_config: Optional[LLMAutoContextSummarizationConfig] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deprecated field names — kept for backward compatibility.
|
||||
# Use enable_auto_context_summarization and auto_context_summarization_config instead.
|
||||
# ---------------------------------------------------------------------------
|
||||
enable_context_summarization: Optional[bool] = None
|
||||
context_summarization_config: Optional[LLMContextSummarizationConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.enable_context_summarization is not None:
|
||||
warnings.warn(
|
||||
"LLMAssistantAggregatorParams.enable_context_summarization is deprecated. "
|
||||
"Use enable_auto_context_summarization instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.enable_auto_context_summarization = self.enable_context_summarization
|
||||
self.enable_context_summarization = None
|
||||
|
||||
if self.context_summarization_config is not None:
|
||||
warnings.warn(
|
||||
"LLMAssistantAggregatorParams.context_summarization_config is deprecated. "
|
||||
"Use auto_context_summarization_config (LLMAutoContextSummarizationConfig) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if isinstance(self.context_summarization_config, LLMContextSummarizationConfig):
|
||||
self.auto_context_summarization_config = (
|
||||
self.context_summarization_config.to_auto_config()
|
||||
)
|
||||
else:
|
||||
# Accept LLMAutoContextSummarizationConfig passed to the deprecated field
|
||||
self.auto_context_summarization_config = self.context_summarization_config # type: ignore[assignment]
|
||||
self.context_summarization_config = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserTurnStoppedMessage:
|
||||
@@ -568,12 +608,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if should_mute_frame:
|
||||
logger.trace(f"{frame.name} suppressed - user currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further and
|
||||
# will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if should_mute_frame and isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
|
||||
should_mute_next_time = False
|
||||
for s in self._params.user_mute_strategies:
|
||||
should_mute_next_time |= await s.process_frame(frame)
|
||||
@@ -602,6 +636,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
if self._params.filter_incomplete_user_turns:
|
||||
config = self._params.user_turn_completion_config or UserTurnCompletionConfig()
|
||||
self._context.add_message({"role": "system", "content": config.completion_instructions})
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
@@ -694,7 +731,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
@@ -824,16 +861,18 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._thought_aggregation: List[TextPartForConcatenation] = []
|
||||
self._thought_start_time: str = ""
|
||||
|
||||
# Context summarization
|
||||
self._summarizer: Optional[LLMContextSummarizer] = None
|
||||
if self._params.enable_context_summarization:
|
||||
self._summarizer = LLMContextSummarizer(
|
||||
context=self._context,
|
||||
config=self._params.context_summarization_config,
|
||||
)
|
||||
self._summarizer.add_event_handler(
|
||||
"on_request_summarization", self._on_request_summarization
|
||||
)
|
||||
# Context summarization — always create the summarizer so that manually
|
||||
# pushed LLMSummarizeContextFrame frames are always handled.
|
||||
# Auto-triggering based on thresholds is only enabled when
|
||||
# enable_auto_context_summarization is True.
|
||||
self._summarizer: Optional[LLMContextSummarizer] = LLMContextSummarizer(
|
||||
context=self._context,
|
||||
config=self._params.auto_context_summarization_config,
|
||||
auto_trigger=self._params.enable_auto_context_summarization,
|
||||
)
|
||||
self._summarizer.add_event_handler(
|
||||
"on_request_summarization", self._on_request_summarization
|
||||
)
|
||||
|
||||
self._register_event_handler("on_assistant_turn_started")
|
||||
self._register_event_handler("on_assistant_turn_stopped")
|
||||
@@ -879,6 +918,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._handle_end_or_cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMAssistantPushAggregationFrame):
|
||||
await self.push_aggregation()
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
|
||||
@@ -234,12 +234,6 @@ class STTMuteFilter(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
logger.trace(f"{frame.__class__.__name__} suppressed - STT currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further
|
||||
# and will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
else:
|
||||
# Pass all other frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -41,7 +41,6 @@ from pipecat.frames.frames import (
|
||||
FrameProcessorResumeFrame,
|
||||
FrameProcessorResumeUrgentFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
UninterruptibleFrame,
|
||||
@@ -240,10 +239,6 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_current_frame: Optional[Frame] = None
|
||||
|
||||
# Set while awaiting push_interruption_task_frame_and_wait() so that
|
||||
# _start_interruption() knows not to cancel the process task.
|
||||
self._wait_for_interruption = False
|
||||
|
||||
# Frame processor events.
|
||||
self._register_event_handler("on_before_process_frame", sync=True)
|
||||
self._register_event_handler("on_after_process_frame", sync=True)
|
||||
@@ -329,7 +324,7 @@ class FrameProcessor(BaseObject):
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessor.interruptions_allowed` is deprecated. "
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.",
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -631,15 +626,6 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption, bypass all queued system frames
|
||||
# and process the frame right away. This is because a previous system
|
||||
# frame might be waiting for the interruption frame blocking the input
|
||||
# task, so this InterruptionFrame would never be dequeued and we'd
|
||||
# deadlock.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
@@ -774,43 +760,32 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_push_frame", frame)
|
||||
|
||||
async def broadcast_interruption(self):
|
||||
"""Broadcast an `InterruptionFrame` both upstream and downstream."""
|
||||
logger.debug(f"{self}: broadcasting interruption")
|
||||
self.__reset_process_task()
|
||||
await self.stop_all_metrics()
|
||||
await self.broadcast_frame(InterruptionFrame)
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self, *, timeout: float = 5.0):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the
|
||||
pipeline task. The task creates a corresponding `InterruptionFrame`
|
||||
and sends it downstream through the pipeline. An `asyncio.Event` is
|
||||
attached to both frames so the caller can wait until the interruption
|
||||
has fully traversed the pipeline. The event is set when the
|
||||
`InterruptionFrame` reaches the pipeline sink. If the frame does
|
||||
not complete within the given timeout, a warning is logged and the
|
||||
event is forcibly set so the caller is unblocked.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for the interruption to complete.
|
||||
.. deprecated:: 0.0.104
|
||||
Use :meth:`broadcast_interruption` instead. This method now
|
||||
delegates to ``broadcast_interruption()`` and ignores *timeout*.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
import warnings
|
||||
|
||||
event = asyncio.Event()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessor.push_interruption_task_frame_and_wait()` is deprecated. "
|
||||
"Use `FrameProcessor.broadcast_interruption()` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(event=event), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for the `InterruptionFrame` to complete and log a warning if it
|
||||
# takes too long. If it does take too long make sure we unblock it,
|
||||
# otherwise we will hang here forever.
|
||||
while not event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{self}: InterruptionFrame has not completed after"
|
||||
f" {timeout}s. Make sure InterruptionFrame.complete()"
|
||||
" is being called (e.g. if the frame is being blocked"
|
||||
" or consumed before reaching the pipeline sink)."
|
||||
)
|
||||
event.set()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def broadcast_frame(self, frame_cls: Type[Frame], **kwargs):
|
||||
"""Broadcasts a frame of the specified class upstream and downstream.
|
||||
@@ -917,15 +892,7 @@ class FrameProcessor(BaseObject):
|
||||
async def _start_interruption(self):
|
||||
"""Start handling an interruption by cancelling current tasks."""
|
||||
try:
|
||||
if self._wait_for_interruption:
|
||||
# If we get here we know the process task was just waiting for
|
||||
# an interruption (push_interruption_task_frame_and_wait()), so
|
||||
# we can't cancel the task because it might still need to do
|
||||
# more things (e.g. pushing a frame after the
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
elif isinstance(self.__process_current_frame, UninterruptibleFrame):
|
||||
if isinstance(self.__process_current_frame, UninterruptibleFrame):
|
||||
# We don't want to cancel UninterruptibleFrame, so we simply
|
||||
# cleanup the queue.
|
||||
self.__reset_process_queue()
|
||||
@@ -949,7 +916,7 @@ class FrameProcessor(BaseObject):
|
||||
try:
|
||||
timestamp = self._clock.get_time() if self._clock else 0
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
logger.trace(f"Pushing {frame} downstream from {self} to {self._next}")
|
||||
|
||||
if self._observer:
|
||||
data = FramePushed(
|
||||
|
||||
@@ -1702,7 +1702,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
|
||||
@@ -642,7 +642,6 @@ class GenesysAudioHookSerializer(FrameSerializer):
|
||||
"""
|
||||
# Binary data = audio
|
||||
if isinstance(data, bytes):
|
||||
logger.debug(f"[AUDIO IN] Received {len(data)} bytes from Genesys")
|
||||
return await self._deserialize_audio(data)
|
||||
|
||||
# Text data = JSON control message
|
||||
|
||||
@@ -12,7 +12,8 @@ transcription WebSocket messages and connection configuration.
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class Word(BaseModel):
|
||||
@@ -68,8 +69,16 @@ class TurnMessage(BaseMessage):
|
||||
transcript: The transcribed text for this turn.
|
||||
end_of_turn_confidence: Confidence score for end-of-turn detection.
|
||||
words: List of individual words with timing and confidence data.
|
||||
language_code: Detected language code (e.g., "es", "fr"). Only present with
|
||||
complete utterances or when end_of_turn is True.
|
||||
language_confidence: Confidence score (0-1) for language detection. Only present
|
||||
with complete utterances or when end_of_turn is True.
|
||||
speaker: Speaker label (e.g., "A", "B"). Only present when speaker_labels is
|
||||
enabled and end_of_turn is True. Maps to 'speaker_label' in JSON response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
type: Literal["Turn"] = "Turn"
|
||||
turn_order: int
|
||||
turn_is_formatted: bool
|
||||
@@ -77,6 +86,21 @@ class TurnMessage(BaseMessage):
|
||||
transcript: str
|
||||
end_of_turn_confidence: float
|
||||
words: List[Word]
|
||||
language_code: Optional[str] = None
|
||||
language_confidence: Optional[float] = None
|
||||
speaker: Optional[str] = Field(default=None, alias="speaker_label")
|
||||
|
||||
|
||||
class SpeechStartedMessage(BaseMessage):
|
||||
"""Message sent when speech is first detected in the audio stream.
|
||||
|
||||
Parameters:
|
||||
type: Always "SpeechStarted" for this message type.
|
||||
timestamp: Audio timestamp in milliseconds when speech was detected.
|
||||
"""
|
||||
|
||||
type: Literal["SpeechStarted"] = "SpeechStarted"
|
||||
timestamp: int
|
||||
|
||||
|
||||
class TerminationMessage(BaseMessage):
|
||||
@@ -94,7 +118,7 @@ class TerminationMessage(BaseMessage):
|
||||
|
||||
|
||||
# Union type for all possible message types
|
||||
AnyMessage = BeginMessage | TurnMessage | TerminationMessage
|
||||
AnyMessage = BeginMessage | TurnMessage | SpeechStartedMessage | TerminationMessage
|
||||
|
||||
|
||||
class AssemblyAIConnectionParams(BaseModel):
|
||||
@@ -106,10 +130,19 @@ class AssemblyAIConnectionParams(BaseModel):
|
||||
formatted_finals: Whether to enable transcript formatting. Defaults to True.
|
||||
word_finalization_max_wait_time: Maximum time to wait for word finalization in milliseconds.
|
||||
end_of_turn_confidence_threshold: Confidence threshold for end-of-turn detection.
|
||||
min_end_of_turn_silence_when_confident: Minimum silence duration when confident about end-of-turn.
|
||||
min_turn_silence: Minimum silence duration when confident about end-of-turn.
|
||||
min_end_of_turn_silence_when_confident: DEPRECATED. Use min_turn_silence instead.
|
||||
max_turn_silence: Maximum silence duration before forcing end-of-turn.
|
||||
keyterms_prompt: List of key terms to guide transcription. Will be JSON serialized before sending.
|
||||
speech_model: Select between English and multilingual models. Defaults to "universal-streaming-english".
|
||||
prompt: Optional text prompt to guide the transcription. Only used when speech_model is "u3-rt-pro".
|
||||
speech_model: Select between English, multilingual, and u3-rt-pro models. Defaults to "u3-rt-pro".
|
||||
language_detection: Enable automatic language detection. Only applicable to
|
||||
universal-streaming-multilingual. When enabled, Turn messages include
|
||||
language_code and language_confidence fields. Defaults to None (not sent).
|
||||
format_turns: Whether to format transcript turns. Defaults to True.
|
||||
speaker_labels: Enable speaker diarization. When enabled, final transcripts
|
||||
(end_of_turn=True) include a speaker field identifying the speaker
|
||||
(e.g., "Speaker A", "Speaker B"). Defaults to None (not sent).
|
||||
"""
|
||||
|
||||
sample_rate: int = 16000
|
||||
@@ -117,9 +150,27 @@ class AssemblyAIConnectionParams(BaseModel):
|
||||
formatted_finals: bool = True
|
||||
word_finalization_max_wait_time: Optional[int] = None
|
||||
end_of_turn_confidence_threshold: Optional[float] = None
|
||||
min_end_of_turn_silence_when_confident: Optional[int] = None
|
||||
min_turn_silence: Optional[int] = None
|
||||
min_end_of_turn_silence_when_confident: Optional[int] = None # Deprecated
|
||||
max_turn_silence: Optional[int] = None
|
||||
keyterms_prompt: Optional[List[str]] = None
|
||||
speech_model: Literal["universal-streaming-english", "universal-streaming-multilingual"] = (
|
||||
"universal-streaming-english"
|
||||
)
|
||||
prompt: Optional[str] = None
|
||||
speech_model: Literal[
|
||||
"universal-streaming-english", "universal-streaming-multilingual", "u3-rt-pro"
|
||||
] = "u3-rt-pro"
|
||||
language_detection: Optional[bool] = None
|
||||
format_turns: bool = True
|
||||
speaker_labels: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def handle_deprecated_param(self):
|
||||
"""Handle deprecated min_end_of_turn_silence_when_confident parameter."""
|
||||
if self.min_end_of_turn_silence_when_confident is not None:
|
||||
logger.warning(
|
||||
"The 'min_end_of_turn_silence_when_confident' parameter is deprecated and will be "
|
||||
"removed in a future version. Please use 'min_turn_silence' instead."
|
||||
)
|
||||
# If min_turn_silence is not set, use the deprecated value
|
||||
if self.min_turn_silence is None:
|
||||
self.min_turn_silence = self.min_end_of_turn_silence_when_confident
|
||||
return self
|
||||
|
||||
@@ -26,6 +26,8 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -41,6 +43,7 @@ from .models import (
|
||||
AssemblyAIConnectionParams,
|
||||
BaseMessage,
|
||||
BeginMessage,
|
||||
SpeechStartedMessage,
|
||||
TerminationMessage,
|
||||
TurnMessage,
|
||||
)
|
||||
@@ -54,6 +57,28 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def map_language_from_assemblyai(language_code: str) -> Language:
|
||||
"""Map AssemblyAI language codes to Pipecat Language enum.
|
||||
|
||||
AssemblyAI returns simple language codes like "es", "fr", etc.
|
||||
This function maps them to the corresponding Language enum values.
|
||||
|
||||
Args:
|
||||
language_code: AssemblyAI language code (e.g., "es", "fr", "de")
|
||||
|
||||
Returns:
|
||||
Corresponding Language enum value, defaulting to Language.EN if not found.
|
||||
"""
|
||||
try:
|
||||
# Try to match the language code directly
|
||||
return Language(language_code.lower())
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Unknown language code from AssemblyAI: {language_code}, defaulting to English"
|
||||
)
|
||||
return Language.EN
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssemblyAISTTSettings(STTSettings):
|
||||
"""Settings for the AssemblyAI STT service.
|
||||
@@ -87,6 +112,8 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws",
|
||||
connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(),
|
||||
vad_force_turn_endpoint: bool = True,
|
||||
should_interrupt: bool = True,
|
||||
speaker_format: Optional[str] = None,
|
||||
ttfs_p99_latency: Optional[float] = ASSEMBLYAI_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -97,18 +124,66 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
language: Language code for transcription. Defaults to English (Language.EN).
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to AssemblyAI's streaming endpoint.
|
||||
connection_params: Connection configuration parameters. Defaults to AssemblyAIConnectionParams().
|
||||
vad_force_turn_endpoint: Whether to force turn endpoint on VAD stop. When True,
|
||||
disables AssemblyAI's model-based turn detection and relies on external VAD
|
||||
to trigger turn endpoints. Automatically sets end_of_turn_confidence_threshold=1.0
|
||||
and max_turn_silence=2000 unless explicitly overridden. Defaults to True.
|
||||
vad_force_turn_endpoint: Controls turn detection mode.
|
||||
When True (Pipecat mode, default): Forces AssemblyAI to return finals ASAP
|
||||
so Pipecat's turn detection (e.g., Smart Turn) decides when the user is done.
|
||||
- min_turn_silence defaults to 100ms (user can override)
|
||||
- max_turn_silence is ALWAYS set equal to min_turn_silence
|
||||
- VAD stop sends ForceEndpoint as ceiling
|
||||
- No UserStarted/StoppedSpeakingFrame emitted from STT
|
||||
When False (AssemblyAI turn detection mode, u3-rt-pro only): AssemblyAI's model
|
||||
controls turn endings using built-in turn detection.
|
||||
- Uses AssemblyAI API defaults for all parameters (unless user explicitly sets them)
|
||||
- Respects all user-provided connection_params as-is
|
||||
- Emits UserStarted/StoppedSpeakingFrame from STT
|
||||
- No ForceEndpoint on VAD stop
|
||||
should_interrupt: Whether to interrupt the bot when the user starts speaking
|
||||
in AssemblyAI turn detection mode (vad_force_turn_endpoint=False). Only applies
|
||||
when using AssemblyAI's built-in turn detection. Defaults to True.
|
||||
speaker_format: Optional format string for speaker labels when diarization is enabled.
|
||||
Use {speaker} for speaker label and {text} for transcript text.
|
||||
Example: "<{speaker}>{text}</{speaker}>" or "{speaker}: {text}"
|
||||
If None, transcript text is not modified. Defaults to None.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
# When vad_force_turn_endpoint is enabled, configure connection params for manual
|
||||
# turn detection mode (disable model-based turn detection)
|
||||
# AssemblyAI turn detection mode (vad_force_turn_endpoint=False) requires the
|
||||
# SpeechStarted event for reliable barge-in. Only u3-rt-pro supports
|
||||
# this. Other models must use Pipecat turn detection.
|
||||
is_u3_pro = connection_params.speech_model == "u3-rt-pro"
|
||||
if not vad_force_turn_endpoint and not is_u3_pro:
|
||||
raise ValueError(
|
||||
f"AssemblyAI turn detection mode (vad_force_turn_endpoint=False) requires "
|
||||
f"u3-rt-pro for SpeechStarted support. Either set "
|
||||
f"vad_force_turn_endpoint=True for {connection_params.speech_model}, "
|
||||
f"or use speech_model='u3-rt-pro'."
|
||||
)
|
||||
|
||||
# Validate that prompt and keyterms_prompt are not both set
|
||||
if connection_params.prompt is not None and connection_params.keyterms_prompt is not None:
|
||||
raise ValueError(
|
||||
"The prompt and keyterms_prompt parameters cannot be used in the same request. "
|
||||
"Please choose either one or the other based on your use case. When you use "
|
||||
"keyterms_prompt, your boosted words are appended to the default prompt automatically. "
|
||||
"Or to boost within prompt: <prompt> + Make sure to boost the words <keyterms> in the audio. "
|
||||
"For more info go to: https://www.assemblyai.com/docs/streaming/universal-3-pro"
|
||||
)
|
||||
|
||||
# Warn if user sets a custom prompt (recommend testing without one first)
|
||||
if connection_params.prompt is not None:
|
||||
logger.warning(
|
||||
"Custom prompt detected. Prompting is a beta feature. We recommend testing "
|
||||
"with no prompt first, as this will use our optimized default prompt for "
|
||||
"voice agents. Bad prompts may lead to bad results. If you'd like to create "
|
||||
"your own prompt, check out our prompting guide at: "
|
||||
"https://www.assemblyai.com/docs/streaming/prompting"
|
||||
)
|
||||
|
||||
# When vad_force_turn_endpoint is enabled, configure connection params
|
||||
# for Pipecat turn detection mode (fast finals for smart turn analyzer)
|
||||
if vad_force_turn_endpoint:
|
||||
connection_params = self._configure_manual_turn_mode(connection_params)
|
||||
connection_params = self._configure_pipecat_turn_mode(connection_params, is_u3_pro)
|
||||
|
||||
super().__init__(
|
||||
sample_rate=connection_params.sample_rate,
|
||||
@@ -124,6 +199,8 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._vad_force_turn_endpoint = vad_force_turn_endpoint
|
||||
self._should_interrupt = should_interrupt
|
||||
self._speaker_format = speaker_format
|
||||
|
||||
self._termination_event = asyncio.Event()
|
||||
self._received_termination = False
|
||||
@@ -135,45 +212,64 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
self._chunk_size_ms = 50
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
def _configure_manual_turn_mode(
|
||||
self, connection_params: AssemblyAIConnectionParams
|
||||
) -> AssemblyAIConnectionParams:
|
||||
"""Configure connection params for manual turn detection mode.
|
||||
self._user_speaking = False
|
||||
|
||||
When vad_force_turn_endpoint is enabled, we want to disable AssemblyAI's
|
||||
model-based turn detection and rely on external VAD. This requires:
|
||||
- end_of_turn_confidence_threshold=1.0 (disable semantic turn detection)
|
||||
- max_turn_silence=2000 (high value since VAD handles turn endings)
|
||||
def _configure_pipecat_turn_mode(
|
||||
self, connection_params: AssemblyAIConnectionParams, is_u3_pro: bool
|
||||
) -> AssemblyAIConnectionParams:
|
||||
"""Configure connection params for Pipecat turn detection mode.
|
||||
|
||||
When vad_force_turn_endpoint is enabled, force AssemblyAI to return
|
||||
finals as fast as possible so Pipecat's smart turn analyzer can decide
|
||||
when the user is done speaking. VAD stop is the absolute ceiling.
|
||||
|
||||
u3-rt-pro:
|
||||
- min_turn_silence defaults to 100ms (user can override)
|
||||
- max_turn_silence is ALWAYS set equal to min_turn_silence
|
||||
to avoid double turn detection (AssemblyAI + Pipecat both analyzing)
|
||||
- If user sets max_turn_silence, it's ignored with a warning
|
||||
- end_of_turn_confidence_threshold: not set (API default)
|
||||
|
||||
universal-streaming-*:
|
||||
- end_of_turn_confidence_threshold=0.0 (disable semantic turn detection)
|
||||
- min_turn_silence=160
|
||||
- max_turn_silence: not set (API default)
|
||||
|
||||
Args:
|
||||
connection_params: The user-provided connection parameters.
|
||||
is_u3_pro: Whether using u3-rt-pro model.
|
||||
|
||||
Returns:
|
||||
Updated connection parameters configured for manual turn mode.
|
||||
Updated connection parameters configured for Pipecat turn mode.
|
||||
"""
|
||||
updates = {}
|
||||
|
||||
# Check end_of_turn_confidence_threshold
|
||||
if connection_params.end_of_turn_confidence_threshold is None:
|
||||
updates["end_of_turn_confidence_threshold"] = 1.0
|
||||
elif connection_params.end_of_turn_confidence_threshold != 1.0:
|
||||
logger.warning(
|
||||
f"vad_force_turn_endpoint is enabled but end_of_turn_confidence_threshold "
|
||||
f"is set to {connection_params.end_of_turn_confidence_threshold}. "
|
||||
f"For manual turn detection mode, this should be 1.0 to disable "
|
||||
f"model-based turn detection. The current value will be used."
|
||||
)
|
||||
if is_u3_pro:
|
||||
# u3-rt-pro: Synchronize max_turn_silence with min_turn_silence
|
||||
min_silence = connection_params.min_turn_silence
|
||||
if min_silence is None:
|
||||
min_silence = 100
|
||||
|
||||
# Check max_turn_silence
|
||||
if connection_params.max_turn_silence is None:
|
||||
updates["max_turn_silence"] = 2000
|
||||
elif connection_params.max_turn_silence < 1000:
|
||||
logger.warning(
|
||||
f"vad_force_turn_endpoint is enabled but max_turn_silence is set to "
|
||||
f"{connection_params.max_turn_silence}ms. With manual turn detection, "
|
||||
f"a higher value (e.g., 2000ms) is recommended to avoid premature "
|
||||
f"turn endings. The current value will be used."
|
||||
)
|
||||
# Warn if user set max_turn_silence (will be overridden)
|
||||
if connection_params.max_turn_silence is not None:
|
||||
logger.warning(
|
||||
f"Your max_turn_silence value ({connection_params.max_turn_silence}ms) will be "
|
||||
f"OVERRIDDEN in Pipecat mode (vad_force_turn_endpoint=True). It will be set to "
|
||||
f"{min_silence}ms (matching min_turn_silence) and SENT to "
|
||||
f"AssemblyAI to avoid double turn detection. To use your max_turn_silence as-is, "
|
||||
f"switch to AssemblyAI turn detection mode (vad_force_turn_endpoint=False)."
|
||||
)
|
||||
|
||||
updates = {
|
||||
"min_turn_silence": min_silence,
|
||||
"max_turn_silence": min_silence,
|
||||
}
|
||||
else:
|
||||
# universal-streaming: Different configuration (works differently)
|
||||
updates = {
|
||||
"end_of_turn_confidence_threshold": 1.0,
|
||||
"min_turn_silence": 160,
|
||||
}
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
@@ -190,9 +286,14 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
return True
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta.
|
||||
"""Apply a settings delta and send UpdateConfiguration if connected.
|
||||
|
||||
Settings are stored but not applied to the active connection.
|
||||
Stores settings changes and sends UpdateConfiguration message to AssemblyAI
|
||||
without reconnecting. Supports updating:
|
||||
- keyterms_prompt: List of terms to boost (can be empty array to clear)
|
||||
- prompt: Custom prompt text (u3-rt-pro only)
|
||||
- max_turn_silence: Maximum silence before forcing turn end
|
||||
- min_turn_silence: Silence before EOT check
|
||||
|
||||
Args:
|
||||
delta: A :class:`STTSettings` (or ``AssemblyAISTTSettings``) delta.
|
||||
@@ -205,18 +306,72 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# # Re-apply manual turn mode config if vad_force_turn_endpoint is active
|
||||
# # and connection_params were updated.
|
||||
# if self._vad_force_turn_endpoint and "connection_params" in changed:
|
||||
# self._settings.connection_params = self._configure_manual_turn_mode(
|
||||
# self._settings.connection_params
|
||||
# )
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
# If websocket is connected, send UpdateConfiguration for supported params
|
||||
if (
|
||||
self._websocket
|
||||
and self._websocket.state is State.OPEN
|
||||
and "connection_params" in changed
|
||||
):
|
||||
# Build UpdateConfiguration message
|
||||
update_config = {"type": "UpdateConfiguration"}
|
||||
conn_params = self._settings.connection_params
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
# Get the old connection_params to see what changed
|
||||
old_conn_params = changed.get("connection_params")
|
||||
|
||||
# Check each potentially changed parameter
|
||||
if (
|
||||
old_conn_params is None
|
||||
or conn_params.keyterms_prompt != old_conn_params.keyterms_prompt
|
||||
):
|
||||
if conn_params.keyterms_prompt is not None:
|
||||
update_config["keyterms_prompt"] = conn_params.keyterms_prompt
|
||||
logger.info(f"Updating keyterms_prompt to: {conn_params.keyterms_prompt}")
|
||||
|
||||
if old_conn_params is None or conn_params.prompt != old_conn_params.prompt:
|
||||
if conn_params.prompt is not None:
|
||||
if conn_params.speech_model != "u3-rt-pro":
|
||||
logger.warning(
|
||||
f"prompt parameter is only supported with u3-rt-pro model, "
|
||||
f"current model is {conn_params.speech_model}"
|
||||
)
|
||||
else:
|
||||
update_config["prompt"] = conn_params.prompt
|
||||
logger.info(f"Updating prompt")
|
||||
|
||||
if (
|
||||
old_conn_params is None
|
||||
or conn_params.max_turn_silence != old_conn_params.max_turn_silence
|
||||
):
|
||||
if conn_params.max_turn_silence is not None:
|
||||
update_config["max_turn_silence"] = conn_params.max_turn_silence
|
||||
logger.info(f"Updating max_turn_silence to: {conn_params.max_turn_silence}ms")
|
||||
|
||||
if (
|
||||
old_conn_params is None
|
||||
or conn_params.min_turn_silence != old_conn_params.min_turn_silence
|
||||
):
|
||||
if conn_params.min_turn_silence is not None:
|
||||
update_config["min_turn_silence"] = conn_params.min_turn_silence
|
||||
logger.info(f"Updating min_turn_silence to: {conn_params.min_turn_silence}ms")
|
||||
|
||||
# Send update if we have parameters to update
|
||||
if len(update_config) > 1: # More than just "type"
|
||||
try:
|
||||
await self._websocket.send(json.dumps(update_config))
|
||||
logger.info(f"Sent UpdateConfiguration: {update_config}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send UpdateConfiguration: {e}")
|
||||
elif "connection_params" in changed:
|
||||
logger.warning(
|
||||
"Connection params changed but WebSocket not connected. "
|
||||
"Settings will be applied on next connection."
|
||||
)
|
||||
|
||||
# Warn about other settings that can't be changed dynamically
|
||||
other_changes = {k: v for k, v in changed.items() if k not in ["connection_params"]}
|
||||
if other_changes:
|
||||
self._warn_unhandled_updated_settings(other_changes)
|
||||
|
||||
return changed
|
||||
|
||||
@@ -283,6 +438,7 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
and self._websocket
|
||||
and self._websocket.state is State.OPEN
|
||||
):
|
||||
self.request_finalize()
|
||||
await self._websocket.send(json.dumps({"type": "ForceEndpoint"}))
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -295,6 +451,9 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
"""Build WebSocket URL with query parameters using urllib.parse.urlencode."""
|
||||
params = {}
|
||||
for k, v in self._settings.connection_params.model_dump().items():
|
||||
# Skip deprecated parameter - it's been migrated to min_turn_silence
|
||||
if k == "min_end_of_turn_silence_when_confident":
|
||||
continue
|
||||
if v is not None:
|
||||
if k == "keyterms_prompt":
|
||||
params[k] = json.dumps(v)
|
||||
@@ -421,6 +580,9 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
# Log raw JSON for Turn messages to debug speaker_label
|
||||
if data.get("type") == "Turn":
|
||||
logger.trace(f"{self} RAW JSON from AssemblyAI: {json.dumps(data, indent=2)}")
|
||||
await self._handle_message(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
@@ -433,6 +595,8 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
return BeginMessage.model_validate(message)
|
||||
elif msg_type == "Turn":
|
||||
return TurnMessage.model_validate(message)
|
||||
elif msg_type == "SpeechStarted":
|
||||
return SpeechStartedMessage.model_validate(message)
|
||||
elif msg_type == "Termination":
|
||||
return TerminationMessage.model_validate(message)
|
||||
else:
|
||||
@@ -449,11 +613,33 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
)
|
||||
elif isinstance(parsed_message, TurnMessage):
|
||||
await self._handle_transcription(parsed_message)
|
||||
elif isinstance(parsed_message, SpeechStartedMessage):
|
||||
await self._handle_speech_started(parsed_message)
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
|
||||
async def _handle_speech_started(self, message: SpeechStartedMessage):
|
||||
"""Handle SpeechStarted event — fast barge-in for AssemblyAI turn detection.
|
||||
|
||||
Broadcasts UserStartedSpeakingFrame to signal the start of user
|
||||
speech, then pushes an interruption to cancel any bot audio.
|
||||
SpeechStarted fires before any transcript arrives, so the turn
|
||||
is cleanly started before any transcription frames are pushed.
|
||||
|
||||
Only applies when using AssemblyAI's built-in turn detection. When using
|
||||
Pipecat turn detection, VAD + smart turn analyzer handle interruptions.
|
||||
"""
|
||||
if self._vad_force_turn_endpoint:
|
||||
return # Pipecat mode: handled by aggregator
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
self._user_speaking = True
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
self._received_termination = True
|
||||
@@ -466,30 +652,109 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
await self.push_frame(EndFrame())
|
||||
|
||||
async def _handle_transcription(self, message: TurnMessage):
|
||||
"""Handle transcription results."""
|
||||
"""Handle transcription results with two turn detection modes.
|
||||
|
||||
Pipecat turn detection (vad_force_turn_endpoint=True):
|
||||
- No UserStarted/StoppedSpeakingFrame from STT
|
||||
- end_of_turn → TranscriptionFrame (finalized set by base class
|
||||
if this is a ForceEndpoint response)
|
||||
- else → InterimTranscriptionFrame
|
||||
|
||||
AssemblyAI turn detection (vad_force_turn_endpoint=False):
|
||||
- UserStartedSpeakingFrame on first transcript
|
||||
- end_of_turn → TranscriptionFrame + UserStoppedSpeakingFrame
|
||||
- else → InterimTranscriptionFrame
|
||||
"""
|
||||
if not message.transcript:
|
||||
return
|
||||
if message.end_of_turn and (
|
||||
not self._settings.connection_params.formatted_finals or message.turn_is_formatted
|
||||
):
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
message.transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._settings.language,
|
||||
message,
|
||||
|
||||
# Use detected language if available with sufficient confidence
|
||||
language = Language.EN
|
||||
if message.language_code and message.language_confidence:
|
||||
if message.language_confidence >= 0.7:
|
||||
language = map_language_from_assemblyai(message.language_code)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Low language detection confidence ({message.language_confidence:.2f}) "
|
||||
f"for language '{message.language_code}', falling back to English"
|
||||
)
|
||||
|
||||
# Handle speaker diarization
|
||||
speaker_id = self._user_id
|
||||
transcript_text = message.transcript
|
||||
|
||||
if message.speaker:
|
||||
speaker_id = message.speaker
|
||||
# Format transcript with speaker labels if format string provided
|
||||
if self._speaker_format:
|
||||
transcript_text = self._speaker_format.format(
|
||||
speaker=message.speaker, text=message.transcript
|
||||
)
|
||||
|
||||
# Determine if this is a final turn from AssemblyAI
|
||||
is_final_turn = message.end_of_turn and (
|
||||
not self._settings.connection_params.format_turns or message.turn_is_formatted
|
||||
)
|
||||
|
||||
if self._vad_force_turn_endpoint:
|
||||
# --- Pipecat turn detection mode ---
|
||||
# No UserStarted/StoppedSpeakingFrame — VAD + smart turn analyzer handle this
|
||||
if is_final_turn:
|
||||
finalize_confirmed = bool(message.turn_is_formatted)
|
||||
if finalize_confirmed:
|
||||
self.confirm_finalize()
|
||||
logger.debug(f'{self} Transcript: "{transcript_text}"')
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript_text,
|
||||
speaker_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
message,
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(transcript_text, True, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript_text,
|
||||
speaker_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
message,
|
||||
)
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(message.transcript, True, self._settings.language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
message.transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._settings.language,
|
||||
message,
|
||||
# --- AssemblyAI turn detection mode ---
|
||||
# SpeechStarted always arrives before transcripts with u3-rt-pro,
|
||||
# so UserStartedSpeakingFrame is guaranteed to be broadcast first.
|
||||
if is_final_turn:
|
||||
# AssemblyAI controls finalization, just mark as finalized
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript_text,
|
||||
speaker_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
message,
|
||||
finalized=True,
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(transcript_text, True, language)
|
||||
await self.stop_processing_metrics()
|
||||
# AAI is authoritative — emit UserStoppedSpeakingFrame immediately.
|
||||
# broadcast_frame pushes downstream (same queue as TranscriptionFrame
|
||||
# above, so ordering is preserved) and upstream.
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
self._user_speaking = False
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript_text,
|
||||
speaker_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
message,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
CancellationReason,
|
||||
ResultReason,
|
||||
SpeechConfig,
|
||||
SpeechRecognizer,
|
||||
@@ -80,6 +81,7 @@ class AzureSTTService(STTService):
|
||||
region: str,
|
||||
language: Language = Language.EN_US,
|
||||
sample_rate: Optional[int] = None,
|
||||
private_endpoint: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
ttfs_p99_latency: Optional[float] = AZURE_TTFS_P99,
|
||||
**kwargs,
|
||||
@@ -91,6 +93,8 @@ class AzureSTTService(STTService):
|
||||
region: Azure region for the Speech service (e.g., 'eastus').
|
||||
language: Language for speech recognition. Defaults to English (US).
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
private_endpoint: Private endpoint for STT behind firewall.
|
||||
See https://docs.azure.cn/en-us/ai-services/speech-service/speech-services-private-link?tabs=portal
|
||||
endpoint_id: Custom model endpoint id.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
@@ -112,6 +116,7 @@ class AzureSTTService(STTService):
|
||||
subscription=api_key,
|
||||
region=region,
|
||||
speech_recognition_language=language_to_azure_language(language),
|
||||
endpoint=private_endpoint,
|
||||
)
|
||||
|
||||
if endpoint_id:
|
||||
@@ -205,6 +210,7 @@ class AzureSTTService(STTService):
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.canceled.connect(self._on_handle_canceled)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
@@ -276,3 +282,13 @@ class AzureSTTService(STTService):
|
||||
result=event,
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
def _on_handle_canceled(self, event):
|
||||
details = event.result.cancellation_details
|
||||
if details.reason == CancellationReason.Error:
|
||||
error_msg = f"Azure STT recognition canceled: {details.reason}"
|
||||
if details.error_details:
|
||||
error_msg += f" - {details.error_details}"
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.push_error(error_msg=error_msg), self.get_event_loop()
|
||||
)
|
||||
|
||||
@@ -561,9 +561,13 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
# User cancellation (from interruption) is expected, not an error
|
||||
if reason == CancellationReason.CancelledByUser:
|
||||
logger.debug(f"{self}: Speech synthesis canceled by user (interruption)")
|
||||
self._audio_queue.put_nowait(None)
|
||||
else:
|
||||
logger.warning(f"{self}: Speech synthesis canceled: {reason}")
|
||||
self._audio_queue.put_nowait(None)
|
||||
details = evt.result.cancellation_details
|
||||
error_msg = f"Azure TTS synthesis canceled: {reason}"
|
||||
if details.error_details:
|
||||
error_msg += f" - {details.error_details}"
|
||||
self._audio_queue.put_nowait(Exception(error_msg))
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
@@ -676,6 +680,9 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
chunk = await self._audio_queue.get()
|
||||
if chunk is None: # End of stream
|
||||
break
|
||||
if isinstance(chunk, Exception): # Error from _handle_canceled
|
||||
yield ErrorFrame(error=str(chunk))
|
||||
break
|
||||
|
||||
if self._first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .flux import *
|
||||
from .sagemaker import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
@@ -675,7 +675,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
self._user_is_speaking = True
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self.start_metrics()
|
||||
await self._call_event_handler("on_start_of_turn", transcript)
|
||||
if transcript:
|
||||
|
||||
448
src/pipecat/services/deepgram/sagemaker/stt.py
Normal file
448
src/pipecat/services/deepgram/sagemaker/stt.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
|
||||
"""Settings for the Deepgram SageMaker STT service.
|
||||
|
||||
See ``_DeepgramSTTSettingsBase`` for full documentation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerSTTSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions configuration. Treated as a
|
||||
delta from a set of sensible defaults — only the fields you
|
||||
set are overridden; all others keep their default values.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
settings = DeepgramSageMakerSTTSettings(
|
||||
model=default_options.model,
|
||||
language=default_options.language,
|
||||
live_options=default_options,
|
||||
)
|
||||
if live_options:
|
||||
settings._merge_live_options_delta(live_options)
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
settings=settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and warn about unhandled changes."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
live_options = LiveOptions(
|
||||
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
|
||||
)
|
||||
|
||||
# Build query string from live_options, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in live_options.to_dict().items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = parsed.get("from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
360
src/pipecat/services/deepgram/sagemaker/tts.py
Normal file
360
src/pipecat/services/deepgram/sagemaker/tts.py
Normal file
@@ -0,0 +1,360 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat TTS service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time speech synthesis with support for interruptions and
|
||||
streaming audio output.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerTTSSettings(TTSSettings):
|
||||
"""Settings for Deepgram SageMaker TTS service.
|
||||
|
||||
Parameters:
|
||||
encoding: Audio encoding format (e.g. "linear16").
|
||||
"""
|
||||
|
||||
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class DeepgramSageMakerTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech synthesis using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
audio generation with support for interruptions via the Clear message.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- ``pipecat-ai[sagemaker]`` installed
|
||||
|
||||
Example::
|
||||
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name="my-deepgram-tts-endpoint",
|
||||
region="us-east-2",
|
||||
voice="aura-2-helena-en",
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerTTSSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
|
||||
deployed (e.g., "my-deepgram-tts-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
settings=DeepgramSageMakerTTSSettings(
|
||||
model=voice,
|
||||
voice=voice,
|
||||
language=None,
|
||||
encoding=encoding,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram TTS query string, creates the BiDi client,
|
||||
starts the streaming session, and launches a background task for processing
|
||||
responses.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram TTS on SageMaker...")
|
||||
|
||||
query_string = (
|
||||
f"model={self._settings.voice}&encoding={self._settings.encoding}"
|
||||
f"&sample_rate={self.sample_rate}"
|
||||
)
|
||||
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/speak",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._client.start_session()
|
||||
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
logger.debug("Connected to Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a Close message to Deepgram, cancels the response processing task,
|
||||
and closes the BiDi session. Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
|
||||
|
||||
try:
|
||||
await self._client.send_json({"type": "Close"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Close message: {e}")
|
||||
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and reconnect if necessary.
|
||||
|
||||
Since all settings are part of the SageMaker session query string,
|
||||
any setting change requires reconnecting to apply the new values.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# Deepgram uses voice as the model, so keep them in sync for metrics
|
||||
if "voice" in changed:
|
||||
self._settings.model = self._settings.voice
|
||||
self._sync_model_name_to_metrics()
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram TTS on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream. Attempts to decode
|
||||
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
|
||||
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
|
||||
a TTSAudioRawFrame downstream.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
payload = result.value.bytes_
|
||||
|
||||
# Try to decode as JSON control message first
|
||||
try:
|
||||
response_data = payload.decode("utf-8")
|
||||
parsed = json.loads(response_data)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {parsed}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {parsed}")
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {parsed}")
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: "
|
||||
f"{parsed.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {parsed}")
|
||||
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Flush"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram TTS on SageMaker.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame, then None (audio comes asynchronously via
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -471,7 +471,7 @@ class DeepgramSTTService(STTService):
|
||||
await self._call_event_handler("on_speech_started", *args, **kwargs)
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _on_utterance_end(self, *args, **kwargs):
|
||||
await self._call_event_handler("on_utterance_end", *args, **kwargs)
|
||||
|
||||
@@ -4,445 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.stt`` instead."""
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
warnings.warn(
|
||||
"Module `pipecat.services.deepgram.stt_sagemaker` is deprecated, "
|
||||
"use `pipecat.services.deepgram.sagemaker.stt` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
|
||||
"""Settings for the Deepgram SageMaker STT service.
|
||||
|
||||
See ``_DeepgramSTTSettingsBase`` for full documentation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerSTTSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions configuration. Treated as a
|
||||
delta from a set of sensible defaults — only the fields you
|
||||
set are overridden; all others keep their default values.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
settings = DeepgramSageMakerSTTSettings(
|
||||
model=default_options.model,
|
||||
language=default_options.language,
|
||||
live_options=default_options,
|
||||
)
|
||||
if live_options:
|
||||
settings._merge_live_options_delta(live_options)
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
settings=settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and warn about unhandled changes."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
live_options = LiveOptions(
|
||||
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
|
||||
)
|
||||
|
||||
# Build query string from live_options, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in live_options.to_dict().items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = parsed.get("from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
from pipecat.services.deepgram.sagemaker.stt import * # noqa: E402, F401, F403
|
||||
|
||||
@@ -4,357 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.tts`` instead."""
|
||||
|
||||
This module provides a Pipecat TTS service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time speech synthesis with support for interruptions and
|
||||
streaming audio output.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
warnings.warn(
|
||||
"Module `pipecat.services.deepgram.tts_sagemaker` is deprecated, "
|
||||
"use `pipecat.services.deepgram.sagemaker.tts` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerTTSSettings(TTSSettings):
|
||||
"""Settings for Deepgram SageMaker TTS service.
|
||||
|
||||
Parameters:
|
||||
encoding: Audio encoding format (e.g. "linear16").
|
||||
"""
|
||||
|
||||
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class DeepgramSageMakerTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech synthesis using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
audio generation with support for interruptions via the Clear message.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- ``pipecat-ai[sagemaker]`` installed
|
||||
|
||||
Example::
|
||||
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name="my-deepgram-tts-endpoint",
|
||||
region="us-east-2",
|
||||
voice="aura-2-helena-en",
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerTTSSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
|
||||
deployed (e.g., "my-deepgram-tts-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
settings=DeepgramSageMakerTTSSettings(
|
||||
model=voice,
|
||||
voice=voice,
|
||||
language=None,
|
||||
encoding=encoding,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram TTS query string, creates the BiDi client,
|
||||
starts the streaming session, and launches a background task for processing
|
||||
responses.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram TTS on SageMaker...")
|
||||
|
||||
query_string = (
|
||||
f"model={self._settings.voice}&encoding={self._settings.encoding}"
|
||||
f"&sample_rate={self.sample_rate}"
|
||||
)
|
||||
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/speak",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._client.start_session()
|
||||
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
logger.debug("Connected to Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a Close message to Deepgram, cancels the response processing task,
|
||||
and closes the BiDi session. Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
|
||||
|
||||
try:
|
||||
await self._client.send_json({"type": "Close"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Close message: {e}")
|
||||
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and reconnect if necessary.
|
||||
|
||||
Since all settings are part of the SageMaker session query string,
|
||||
any setting change requires reconnecting to apply the new values.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# Deepgram uses voice as the model, so keep them in sync for metrics
|
||||
if "voice" in changed:
|
||||
self._settings.model = self._settings.voice
|
||||
self._sync_model_name_to_metrics()
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram TTS on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream. Attempts to decode
|
||||
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
|
||||
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
|
||||
a TTSAudioRawFrame downstream.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
payload = result.value.bytes_
|
||||
|
||||
# Try to decode as JSON control message first
|
||||
try:
|
||||
response_data = payload.decode("utf-8")
|
||||
parsed = json.loads(response_data)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {parsed}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {parsed}")
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {parsed}")
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: "
|
||||
f"{parsed.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {parsed}")
|
||||
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Flush"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram TTS on SageMaker.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame, then None (audio comes asynchronously via
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
from pipecat.services.deepgram.sagemaker.tts import * # noqa: E402, F401, F403
|
||||
|
||||
@@ -861,6 +861,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
await self._handle_transcription(text, True, language)
|
||||
|
||||
finalized = self._settings.commit_strategy == CommitStrategy.MANUAL
|
||||
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
@@ -868,6 +870,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
finalized=finalized,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -902,6 +905,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
await self._handle_transcription(text, True, language)
|
||||
|
||||
finalized = self._settings.commit_strategy == CommitStrategy.MANUAL
|
||||
|
||||
# This message is sent after committed_transcript when include_timestamps=true.
|
||||
# It contains the full transcript data including text and word-level timestamps.
|
||||
await self.push_frame(
|
||||
@@ -911,5 +916,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
finalized=finalized,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -613,7 +613,7 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _on_speech_ended(self):
|
||||
"""Handle speech end event from Gladia.
|
||||
|
||||
@@ -1265,7 +1265,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# combination with the context aggregator default
|
||||
# turn strategies.
|
||||
logger.debug("Gemini VAD: interrupted signal received")
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
elif message.server_content and message.server_content.model_turn:
|
||||
await self._handle_msg_model_turn(message)
|
||||
elif (
|
||||
|
||||
@@ -571,8 +571,11 @@ class GrokRealtimeLLMService(LLMService):
|
||||
elif evt.type == "response.function_call_arguments.done":
|
||||
await self._handle_evt_function_call_arguments_done(evt)
|
||||
elif evt.type == "error":
|
||||
await self._handle_evt_error(evt)
|
||||
return
|
||||
if evt.error.code == "response_cancel_not_active":
|
||||
logger.debug(f"{self} {evt.error.message}")
|
||||
else:
|
||||
await self._handle_evt_error(evt)
|
||||
return
|
||||
|
||||
async def _handle_evt_conversation_created(self, evt):
|
||||
"""Handle conversation.created event - first event after connecting."""
|
||||
@@ -731,7 +734,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
"""Handle speech started event from VAD."""
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
"""Handle speech stopped event from VAD."""
|
||||
|
||||
@@ -62,10 +62,12 @@ class HeyGenCallbacks(BaseModel):
|
||||
"""Callback handlers for HeyGen events.
|
||||
|
||||
Parameters:
|
||||
on_participant_connected: Called when a participant connects
|
||||
on_participant_disconnected: Called when a participant disconnects
|
||||
on_connected: Called when the bot connects to the LiveKit room.
|
||||
on_participant_connected: Called when a participant connects.
|
||||
on_participant_disconnected: Called when a participant disconnects.
|
||||
"""
|
||||
|
||||
on_connected: Callable[[], Awaitable[None]]
|
||||
on_participant_connected: Callable[[str], Awaitable[None]]
|
||||
on_participant_disconnected: Callable[[str], Awaitable[None]]
|
||||
|
||||
@@ -251,6 +253,7 @@ class HeyGenClient:
|
||||
logger.debug(f"HeyGenClient send_interval: {self._send_interval}")
|
||||
await self._ws_connect()
|
||||
await self._livekit_connect()
|
||||
self._call_event_callback(self._callbacks.on_connected)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the client and terminate all connections.
|
||||
|
||||
@@ -128,6 +128,7 @@ class HeyGenVideoService(AIService):
|
||||
session_request=self._session_request,
|
||||
service_type=self._service_type,
|
||||
callbacks=HeyGenCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
),
|
||||
@@ -144,6 +145,10 @@ class HeyGenVideoService(AIService):
|
||||
await self._client.cleanup()
|
||||
self._client = None
|
||||
|
||||
async def _on_connected(self):
|
||||
"""Handle bot connected to LiveKit room."""
|
||||
logger.info("HeyGen bot connected to LiveKit room")
|
||||
|
||||
async def _on_participant_connected(self, participant_id: str):
|
||||
"""Handle participant connected events."""
|
||||
logger.info(f"Participant connected {participant_id}")
|
||||
|
||||
@@ -62,6 +62,7 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
DEFAULT_SUMMARIZATION_TIMEOUT,
|
||||
LLMContextSummarizationUtil,
|
||||
)
|
||||
|
||||
@@ -436,8 +437,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
last_index = -1
|
||||
error = None
|
||||
|
||||
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
|
||||
|
||||
try:
|
||||
summary, last_index = await self._generate_summary(frame)
|
||||
summary, last_index = await asyncio.wait_for(
|
||||
self._generate_summary(frame),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self.push_error(error_msg=f"Context summarization timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
error = f"Error generating context summary: {e}"
|
||||
await self.push_error(error, exception=e)
|
||||
|
||||
@@ -618,9 +618,12 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
await self._handle_evt_function_call_arguments_done(evt)
|
||||
elif evt.type == "error":
|
||||
if not await self._maybe_handle_evt_retrieve_conversation_item_error(evt):
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
if evt.error.code == "response_cancel_not_active":
|
||||
logger.debug(f"{self} {evt.error.message}")
|
||||
else:
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
|
||||
@traced_openai_realtime(operation="llm_setup")
|
||||
async def _handle_evt_session_created(self, evt):
|
||||
@@ -836,7 +839,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -639,7 +639,7 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
logger.debug("Server VAD: speech started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def _handle_speech_stopped(self, evt: dict):
|
||||
|
||||
@@ -544,9 +544,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self._handle_evt_audio_transcript_delta(evt)
|
||||
elif evt.type == "error":
|
||||
if not await self._maybe_handle_evt_retrieve_conversation_item_error(evt):
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
if evt.error.code == "response_cancel_not_active":
|
||||
logger.debug(f"{self} {evt.error.message}")
|
||||
else:
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
|
||||
@traced_openai_realtime(operation="llm_setup")
|
||||
async def _handle_evt_session_created(self, evt):
|
||||
@@ -706,7 +709,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -147,10 +147,10 @@ class RimeTTSService(AudioContextTTSService):
|
||||
Parameters:
|
||||
language: Language for synthesis. Defaults to English.
|
||||
segment: Text segmentation mode ("immediate", "bySentence", "never").
|
||||
speed_alpha: Speech speed multiplier.
|
||||
repetition_penalty: Token repetition penalty (arcana only).
|
||||
temperature: Sampling temperature (arcana only).
|
||||
top_p: Cumulative probability threshold (arcana only).
|
||||
speed_alpha: Speech speed multiplier (mistv2 only).
|
||||
reduce_latency: Whether to reduce latency at potential quality cost (mistv2 only).
|
||||
pause_between_brackets: Whether to add pauses between bracketed content (mistv2 only).
|
||||
phonemize_between_brackets: Whether to phonemize bracketed content (mistv2 only).
|
||||
@@ -160,12 +160,12 @@ class RimeTTSService(AudioContextTTSService):
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
segment: Optional[str] = None
|
||||
speed_alpha: Optional[float] = None
|
||||
# Arcana params
|
||||
repetition_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
# Mistv2 params
|
||||
speed_alpha: Optional[float] = None
|
||||
reduce_latency: Optional[bool] = None
|
||||
pause_between_brackets: Optional[bool] = None
|
||||
phonemize_between_brackets: Optional[bool] = None
|
||||
@@ -230,12 +230,12 @@ class RimeTTSService(AudioContextTTSService):
|
||||
else None,
|
||||
segment=params.segment,
|
||||
inlineSpeedAlpha=None, # Not applicable here
|
||||
speedAlpha=params.speed_alpha,
|
||||
# Arcana params
|
||||
repetition_penalty=params.repetition_penalty,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
# Mistv2 params
|
||||
speedAlpha=params.speed_alpha,
|
||||
reduceLatency=params.reduce_latency,
|
||||
pauseBetweenBrackets=params.pause_between_brackets,
|
||||
phonemizeBetweenBrackets=params.phonemize_between_brackets,
|
||||
@@ -301,6 +301,8 @@ class RimeTTSService(AudioContextTTSService):
|
||||
params["lang"] = self._settings.language
|
||||
if self._settings.segment is not None:
|
||||
params["segment"] = self._settings.segment
|
||||
if self._settings.speedAlpha is not None:
|
||||
params["speedAlpha"] = self._settings.speedAlpha
|
||||
|
||||
if self._settings.model == "arcana":
|
||||
if self._settings.repetition_penalty is not None:
|
||||
@@ -310,8 +312,6 @@ class RimeTTSService(AudioContextTTSService):
|
||||
if self._settings.top_p is not None:
|
||||
params["top_p"] = self._settings.top_p
|
||||
else: # mistv2/mist
|
||||
if self._settings.speedAlpha is not None:
|
||||
params["speedAlpha"] = self._settings.speedAlpha
|
||||
if self._settings.reduceLatency is not None:
|
||||
params["reduceLatency"] = self._settings.reduceLatency
|
||||
if self._settings.pauseBetweenBrackets is not None:
|
||||
|
||||
@@ -266,15 +266,10 @@ class SarvamSTTService(STTService):
|
||||
|
||||
# Initialize Sarvam SDK client
|
||||
self._sdk_headers = sdk_headers()
|
||||
# NOTE: We avoid passing non-standard kwargs here because different sarvamai
|
||||
# versions expose different constructor signatures (static type checkers
|
||||
# complain otherwise). We instead inject headers best-effort below.
|
||||
self._sarvam_client = AsyncSarvamAI(api_subscription_key=api_key)
|
||||
for attr in ("default_headers", "_default_headers", "headers", "_headers"):
|
||||
d = getattr(self._sarvam_client, attr, None)
|
||||
if isinstance(d, dict):
|
||||
d.update(self._sdk_headers)
|
||||
break
|
||||
# Pass Pipecat SDK headers directly at client construction time so they are
|
||||
# merged by the Sarvam SDK's client wrapper and consistently applied to
|
||||
# WebSocket handshake requests.
|
||||
self._sarvam_client = AsyncSarvamAI(api_subscription_key=api_key, headers=self._sdk_headers)
|
||||
self._websocket_context = None
|
||||
self._socket_client = None
|
||||
self._receive_task = None
|
||||
@@ -517,20 +512,26 @@ class SarvamSTTService(STTService):
|
||||
connect_kwargs["prompt"] = self._settings.prompt
|
||||
|
||||
def _connect_with_sdk_headers(connect_fn, **kwargs):
|
||||
# Different SDK versions may use different kwarg names.
|
||||
# If prompt is unsupported at connect-time, retry without it.
|
||||
# Headers are supplied through request_options because this is a
|
||||
# documented SDK parameter that survives SDK signature changes.
|
||||
request_options = {"additional_headers": self._sdk_headers}
|
||||
|
||||
attempts = [kwargs]
|
||||
if "prompt" in kwargs:
|
||||
attempts.append({k: v for k, v in kwargs.items() if k != "prompt"})
|
||||
|
||||
last_type_error = None
|
||||
for attempt_kwargs in attempts:
|
||||
for header_kw in ("headers", "additional_headers", "extra_headers"):
|
||||
try:
|
||||
return connect_fn(**attempt_kwargs, **{header_kw: self._sdk_headers})
|
||||
except TypeError as e:
|
||||
last_type_error = e
|
||||
try:
|
||||
return connect_fn(
|
||||
**attempt_kwargs,
|
||||
request_options=request_options,
|
||||
)
|
||||
except TypeError as e:
|
||||
last_type_error = e
|
||||
try:
|
||||
# Fallback for SDK builds that don't expose request_options.
|
||||
return connect_fn(**attempt_kwargs)
|
||||
except TypeError as e:
|
||||
last_type_error = e
|
||||
@@ -643,7 +644,7 @@ class SarvamSTTService(STTService):
|
||||
logger.debug("User started speaking")
|
||||
await self._call_event_handler("on_speech_started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
elif signal == "END_SPEECH":
|
||||
logger.debug("User stopped speaking")
|
||||
|
||||
@@ -1013,12 +1013,14 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
ws_additional_headers = {
|
||||
"api-subscription-key": self._api_key,
|
||||
**sdk_headers(),
|
||||
}
|
||||
|
||||
self._websocket = await websocket_connect(
|
||||
self._websocket_url,
|
||||
additional_headers={
|
||||
"api-subscription-key": self._api_key,
|
||||
**sdk_headers(),
|
||||
},
|
||||
additional_headers=ws_additional_headers,
|
||||
)
|
||||
logger.debug("Connected to Sarvam TTS Websocket")
|
||||
await self._send_config()
|
||||
|
||||
@@ -836,7 +836,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
# await self.start_processing_metrics()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_end_of_turn(self, message: dict[str, Any]) -> None:
|
||||
"""Handle EndOfTurn events.
|
||||
|
||||
@@ -94,6 +94,7 @@ class TavusVideoService(AIService):
|
||||
"""
|
||||
await super().setup(setup)
|
||||
callbacks = TavusCallbacks(
|
||||
on_joined=self._on_joined,
|
||||
on_participant_joined=self._on_participant_joined,
|
||||
on_participant_left=self._on_participant_left,
|
||||
)
|
||||
@@ -119,6 +120,10 @@ class TavusVideoService(AIService):
|
||||
await self._client.cleanup()
|
||||
self._client = None
|
||||
|
||||
async def _on_joined(self, data):
|
||||
"""Handle bot joined the Daily room."""
|
||||
logger.info("Tavus bot joined Daily room")
|
||||
|
||||
async def _on_participant_left(self, participant, reason):
|
||||
"""Handle participant leaving the session."""
|
||||
participant_id = participant["id"]
|
||||
|
||||
@@ -39,6 +39,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMAssistantPushAggregationFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartFrame,
|
||||
@@ -67,10 +68,16 @@ class TTSContext:
|
||||
"""Context information for a TTS request.
|
||||
|
||||
Attributes:
|
||||
append_to_context: Whether this TTS output should be appended to the conversation context.
|
||||
append_to_context: Whether this TTS output should be appended to the
|
||||
conversation context after it is spoken.
|
||||
push_assistant_aggregation: Whether to push an
|
||||
``LLMAssistantPushAggregationFrame`` after the TTS has finished
|
||||
speaking, forcing the assistant aggregator to commit its current
|
||||
text buffer to the conversation context.
|
||||
"""
|
||||
|
||||
append_to_context: bool = True
|
||||
push_assistant_aggregation: Optional[bool] = False
|
||||
|
||||
|
||||
class TextAggregationMode(str, Enum):
|
||||
@@ -641,10 +648,13 @@ class TTSService(AIService):
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
# If we are not receiving text from the LLM, we can assume that the SpeakFrame should be automatically added to the context
|
||||
push_assistant_aggregation = frame.append_to_context and not self._llm_response_started
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE),
|
||||
append_tts_text_to_context=frame.append_to_context,
|
||||
push_assistant_aggregation=push_assistant_aggregation,
|
||||
)
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
@@ -809,6 +819,7 @@ class TTSService(AIService):
|
||||
src_frame: AggregatedTextFrame,
|
||||
includes_inter_frame_spaces: Optional[bool] = False,
|
||||
append_tts_text_to_context: Optional[bool] = True,
|
||||
push_assistant_aggregation: Optional[bool] = False,
|
||||
):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
@@ -876,7 +887,8 @@ class TTSService(AIService):
|
||||
self._tts_contexts[context_id] = TTSContext(
|
||||
append_to_context=append_tts_text_to_context
|
||||
if append_tts_text_to_context is not None
|
||||
else True
|
||||
else True,
|
||||
push_assistant_aggregation=push_assistant_aggregation,
|
||||
)
|
||||
|
||||
# Apply any final text preparation (e.g., trailing space)
|
||||
@@ -905,6 +917,8 @@ class TTSService(AIService):
|
||||
if append_tts_text_to_context is not None:
|
||||
frame.append_to_context = append_tts_text_to_context
|
||||
await self.push_frame(frame)
|
||||
if push_assistant_aggregation:
|
||||
await self.push_frame(LLMAssistantPushAggregationFrame())
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
has_started = False
|
||||
@@ -988,6 +1002,9 @@ class TTSService(AIService):
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
frame.context_id = context_id
|
||||
if context_id in self._tts_contexts:
|
||||
if self._tts_contexts[context_id].push_assistant_aggregation:
|
||||
await self.push_frame(LLMAssistantPushAggregationFrame())
|
||||
else:
|
||||
# Assumption: word-by-word text frames don't include spaces, so
|
||||
# we can rely on the default includes_inter_frame_spaces=False
|
||||
|
||||
@@ -558,7 +558,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
|
||||
@@ -24,7 +24,9 @@ from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotConnectedFrame,
|
||||
CancelFrame,
|
||||
ClientConnectedFrame,
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -2070,6 +2072,8 @@ class DailyTransport(BaseTransport):
|
||||
Event handlers available:
|
||||
|
||||
- on_joined: Called when the bot joins the room. Args: (data: dict)
|
||||
- on_connected: Called when the bot connects to the room (alias for
|
||||
on_joined). Args: (data: dict)
|
||||
- on_left: Called when the bot leaves the room.
|
||||
- on_before_leave: [sync] Called just before the bot leaves the room.
|
||||
- on_error: Called when a transport error occurs. Args: (error: str)
|
||||
@@ -2187,6 +2191,7 @@ class DailyTransport(BaseTransport):
|
||||
# Register supported handlers. The user will only be able to register
|
||||
# these handlers.
|
||||
self._register_event_handler("on_active_speaker_changed")
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_joined")
|
||||
self._register_event_handler("on_left")
|
||||
self._register_event_handler("on_error")
|
||||
@@ -2578,6 +2583,10 @@ class DailyTransport(BaseTransport):
|
||||
if error:
|
||||
await self._on_error(f"Unable to start transcription: {error}")
|
||||
await self._call_event_handler("on_joined", data)
|
||||
# Also call on_connected for compatibility with other transports
|
||||
await self._call_event_handler("on_connected", data)
|
||||
if self._input:
|
||||
await self._input.push_frame(BotConnectedFrame())
|
||||
|
||||
async def _on_left(self):
|
||||
"""Handle room left events."""
|
||||
@@ -2716,6 +2725,8 @@ class DailyTransport(BaseTransport):
|
||||
await self._call_event_handler("on_participant_joined", participant)
|
||||
# Also call on_client_connected for compatibility with other transports
|
||||
await self._call_event_handler("on_client_connected", participant)
|
||||
if self._input:
|
||||
await self._input.push_frame(ClientConnectedFrame())
|
||||
|
||||
async def _on_participant_left(self, participant, reason):
|
||||
"""Handle participant left events."""
|
||||
|
||||
@@ -23,9 +23,11 @@ from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotConnectedFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
ClientConnectedFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
@@ -339,6 +341,7 @@ class HeyGenTransport(BaseTransport):
|
||||
session_request=session_request,
|
||||
service_type=service_type,
|
||||
callbacks=HeyGenCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
),
|
||||
@@ -349,9 +352,16 @@ class HeyGenTransport(BaseTransport):
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
# these handlers.
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_client_connected")
|
||||
self._register_event_handler("on_client_disconnected")
|
||||
|
||||
async def _on_connected(self):
|
||||
"""Handle bot connected to LiveKit room."""
|
||||
await self._call_event_handler("on_connected")
|
||||
if self._input:
|
||||
await self._input.push_frame(BotConnectedFrame())
|
||||
|
||||
async def _on_participant_disconnected(self, participant_id: str):
|
||||
logger.debug(f"HeyGen participant {participant_id} disconnected")
|
||||
if participant_id != "heygen":
|
||||
@@ -387,6 +397,8 @@ class HeyGenTransport(BaseTransport):
|
||||
async def _on_client_connected(self, participant: Any):
|
||||
"""Handle client connected events."""
|
||||
await self._call_event_handler("on_client_connected", participant)
|
||||
if self._input:
|
||||
await self._input.push_frame(ClientConnectedFrame())
|
||||
|
||||
async def _on_client_disconnected(self, participant: Any):
|
||||
"""Handle client disconnected events."""
|
||||
|
||||
0
src/pipecat/transports/lemonslice/__init__.py
Normal file
0
src/pipecat/transports/lemonslice/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user